forked from phoenix/litellm-mirror
Merge remote-tracking branch 'upstream/main'
This commit is contained in:
commit
18bf68298f
237 changed files with 25040 additions and 3688 deletions
|
@ -1,4 +1,4 @@
|
||||||
version: 2.1
|
version: 4.3.4
|
||||||
jobs:
|
jobs:
|
||||||
local_testing:
|
local_testing:
|
||||||
docker:
|
docker:
|
||||||
|
@ -189,7 +189,7 @@ jobs:
|
||||||
command: |
|
command: |
|
||||||
docker run -d \
|
docker run -d \
|
||||||
-p 4000:4000 \
|
-p 4000:4000 \
|
||||||
-e DATABASE_URL=$PROXY_DOCKER_DB_URL \
|
-e DATABASE_URL=$PROXY_DATABASE_URL \
|
||||||
-e AZURE_API_KEY=$AZURE_API_KEY \
|
-e AZURE_API_KEY=$AZURE_API_KEY \
|
||||||
-e REDIS_HOST=$REDIS_HOST \
|
-e REDIS_HOST=$REDIS_HOST \
|
||||||
-e REDIS_PASSWORD=$REDIS_PASSWORD \
|
-e REDIS_PASSWORD=$REDIS_PASSWORD \
|
||||||
|
@ -199,6 +199,7 @@ jobs:
|
||||||
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
|
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
|
||||||
-e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
|
-e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
|
||||||
-e AWS_REGION_NAME=$AWS_REGION_NAME \
|
-e AWS_REGION_NAME=$AWS_REGION_NAME \
|
||||||
|
-e AUTO_INFER_REGION=True \
|
||||||
-e OPENAI_API_KEY=$OPENAI_API_KEY \
|
-e OPENAI_API_KEY=$OPENAI_API_KEY \
|
||||||
-e LANGFUSE_PROJECT1_PUBLIC=$LANGFUSE_PROJECT1_PUBLIC \
|
-e LANGFUSE_PROJECT1_PUBLIC=$LANGFUSE_PROJECT1_PUBLIC \
|
||||||
-e LANGFUSE_PROJECT2_PUBLIC=$LANGFUSE_PROJECT2_PUBLIC \
|
-e LANGFUSE_PROJECT2_PUBLIC=$LANGFUSE_PROJECT2_PUBLIC \
|
||||||
|
@ -209,9 +210,7 @@ jobs:
|
||||||
my-app:latest \
|
my-app:latest \
|
||||||
--config /app/config.yaml \
|
--config /app/config.yaml \
|
||||||
--port 4000 \
|
--port 4000 \
|
||||||
--num_workers 8 \
|
|
||||||
--detailed_debug \
|
--detailed_debug \
|
||||||
--run_gunicorn \
|
|
||||||
- run:
|
- run:
|
||||||
name: Install curl and dockerize
|
name: Install curl and dockerize
|
||||||
command: |
|
command: |
|
||||||
|
@ -226,7 +225,7 @@ jobs:
|
||||||
background: true
|
background: true
|
||||||
- run:
|
- run:
|
||||||
name: Wait for app to be ready
|
name: Wait for app to be ready
|
||||||
command: dockerize -wait http://localhost:4000 -timeout 1m
|
command: dockerize -wait http://localhost:4000 -timeout 5m
|
||||||
- run:
|
- run:
|
||||||
name: Run tests
|
name: Run tests
|
||||||
command: |
|
command: |
|
||||||
|
|
51
.devcontainer/devcontainer.json
Normal file
51
.devcontainer/devcontainer.json
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
{
|
||||||
|
"name": "Python 3.11",
|
||||||
|
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
|
||||||
|
"image": "mcr.microsoft.com/devcontainers/python:3.11-bookworm",
|
||||||
|
// https://github.com/devcontainers/images/tree/main/src/python
|
||||||
|
// https://mcr.microsoft.com/en-us/product/devcontainers/python/tags
|
||||||
|
|
||||||
|
// "build": {
|
||||||
|
// "dockerfile": "Dockerfile",
|
||||||
|
// "context": ".."
|
||||||
|
// },
|
||||||
|
|
||||||
|
// Features to add to the dev container. More info: https://containers.dev/features.
|
||||||
|
// "features": {},
|
||||||
|
|
||||||
|
// Configure tool-specific properties.
|
||||||
|
"customizations": {
|
||||||
|
// Configure properties specific to VS Code.
|
||||||
|
"vscode": {
|
||||||
|
"settings": {},
|
||||||
|
"extensions": [
|
||||||
|
"ms-python.python",
|
||||||
|
"ms-python.vscode-pylance",
|
||||||
|
"GitHub.copilot",
|
||||||
|
"GitHub.copilot-chat"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
||||||
|
"forwardPorts": [4000],
|
||||||
|
|
||||||
|
"containerEnv": {
|
||||||
|
"LITELLM_LOG": "DEBUG"
|
||||||
|
},
|
||||||
|
|
||||||
|
// Use 'portsAttributes' to set default properties for specific forwarded ports.
|
||||||
|
// More info: https://containers.dev/implementors/json_reference/#port-attributes
|
||||||
|
"portsAttributes": {
|
||||||
|
"4000": {
|
||||||
|
"label": "LiteLLM Server",
|
||||||
|
"onAutoForward": "notify"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// More info: https://aka.ms/dev-containers-non-root.
|
||||||
|
// "remoteUser": "litellm",
|
||||||
|
|
||||||
|
// Use 'postCreateCommand' to run commands after the container is created.
|
||||||
|
"postCreateCommand": "pipx install poetry && poetry install -E extra_proxy -E proxy"
|
||||||
|
}
|
10
.git-blame-ignore-revs
Normal file
10
.git-blame-ignore-revs
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
# Add the commit hash of any commit you want to ignore in `git blame` here.
|
||||||
|
# One commit hash per line.
|
||||||
|
#
|
||||||
|
# The GitHub Blame UI will use this file automatically!
|
||||||
|
#
|
||||||
|
# Run this command to always ignore formatting commits in `git blame`
|
||||||
|
# git config blame.ignoreRevsFile .git-blame-ignore-revs
|
||||||
|
|
||||||
|
# Update pydantic code to fix warnings (GH-3600)
|
||||||
|
876840e9957bc7e9f7d6a2b58c4d7c53dad16481
|
22
.github/pull_request_template.md
vendored
22
.github/pull_request_template.md
vendored
|
@ -1,6 +1,3 @@
|
||||||
<!-- This is just examples. You can remove all items if you want. -->
|
|
||||||
<!-- Please remove all comments. -->
|
|
||||||
|
|
||||||
## Title
|
## Title
|
||||||
|
|
||||||
<!-- e.g. "Implement user authentication feature" -->
|
<!-- e.g. "Implement user authentication feature" -->
|
||||||
|
@ -18,7 +15,6 @@
|
||||||
🐛 Bug Fix
|
🐛 Bug Fix
|
||||||
🧹 Refactoring
|
🧹 Refactoring
|
||||||
📖 Documentation
|
📖 Documentation
|
||||||
💻 Development Environment
|
|
||||||
🚄 Infrastructure
|
🚄 Infrastructure
|
||||||
✅ Test
|
✅ Test
|
||||||
|
|
||||||
|
@ -26,22 +22,8 @@
|
||||||
|
|
||||||
<!-- List of changes -->
|
<!-- List of changes -->
|
||||||
|
|
||||||
## Testing
|
## [REQUIRED] Testing - Attach a screenshot of any new tests passing locall
|
||||||
|
If UI changes, send a screenshot/GIF of working UI fixes
|
||||||
|
|
||||||
<!-- Test procedure -->
|
<!-- Test procedure -->
|
||||||
|
|
||||||
## Notes
|
|
||||||
|
|
||||||
<!-- Test results -->
|
|
||||||
|
|
||||||
<!-- Points to note for the reviewer, consultation content, concerns -->
|
|
||||||
|
|
||||||
## Pre-Submission Checklist (optional but appreciated):
|
|
||||||
|
|
||||||
- [ ] I have included relevant documentation updates (stored in /docs/my-website)
|
|
||||||
|
|
||||||
## OS Tests (optional but appreciated):
|
|
||||||
|
|
||||||
- [ ] Tested on Windows
|
|
||||||
- [ ] Tested on MacOS
|
|
||||||
- [ ] Tested on Linux
|
|
||||||
|
|
19
.github/workflows/interpret_load_test.py
vendored
19
.github/workflows/interpret_load_test.py
vendored
|
@ -64,6 +64,11 @@ if __name__ == "__main__":
|
||||||
) # Replace with your repository's username and name
|
) # Replace with your repository's username and name
|
||||||
latest_release = repo.get_latest_release()
|
latest_release = repo.get_latest_release()
|
||||||
print("got latest release: ", latest_release)
|
print("got latest release: ", latest_release)
|
||||||
|
print(latest_release.title)
|
||||||
|
print(latest_release.tag_name)
|
||||||
|
|
||||||
|
release_version = latest_release.title
|
||||||
|
|
||||||
print("latest release body: ", latest_release.body)
|
print("latest release body: ", latest_release.body)
|
||||||
print("markdown table: ", markdown_table)
|
print("markdown table: ", markdown_table)
|
||||||
|
|
||||||
|
@ -74,8 +79,22 @@ if __name__ == "__main__":
|
||||||
start_index = latest_release.body.find("Load Test LiteLLM Proxy Results")
|
start_index = latest_release.body.find("Load Test LiteLLM Proxy Results")
|
||||||
existing_release_body = latest_release.body[:start_index]
|
existing_release_body = latest_release.body[:start_index]
|
||||||
|
|
||||||
|
docker_run_command = f"""
|
||||||
|
\n\n
|
||||||
|
## Docker Run LiteLLM Proxy
|
||||||
|
|
||||||
|
```
|
||||||
|
docker run \\
|
||||||
|
-e STORE_MODEL_IN_DB=True \\
|
||||||
|
-p 4000:4000 \\
|
||||||
|
ghcr.io/berriai/litellm:main-{release_version}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
print("docker run command: ", docker_run_command)
|
||||||
|
|
||||||
new_release_body = (
|
new_release_body = (
|
||||||
existing_release_body
|
existing_release_body
|
||||||
|
+ docker_run_command
|
||||||
+ "\n\n"
|
+ "\n\n"
|
||||||
+ "### Don't want to maintain your internal proxy? get in touch 🎉"
|
+ "### Don't want to maintain your internal proxy? get in touch 🎉"
|
||||||
+ "\nHosted Proxy Alpha: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat"
|
+ "\nHosted Proxy Alpha: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat"
|
||||||
|
|
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -1,5 +1,6 @@
|
||||||
.venv
|
.venv
|
||||||
.env
|
.env
|
||||||
|
litellm/proxy/myenv/*
|
||||||
litellm_uuid.txt
|
litellm_uuid.txt
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.pyc
|
*.pyc
|
||||||
|
@ -52,3 +53,6 @@ litellm/proxy/_new_secret_config.yaml
|
||||||
litellm/proxy/_new_secret_config.yaml
|
litellm/proxy/_new_secret_config.yaml
|
||||||
litellm/proxy/_super_secret_config.yaml
|
litellm/proxy/_super_secret_config.yaml
|
||||||
litellm/proxy/_super_secret_config.yaml
|
litellm/proxy/_super_secret_config.yaml
|
||||||
|
litellm/proxy/myenv/bin/activate
|
||||||
|
litellm/proxy/myenv/bin/Activate.ps1
|
||||||
|
myenv/*
|
|
@ -226,6 +226,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
|
||||||
| [deepinfra](https://docs.litellm.ai/docs/providers/deepinfra) | ✅ | ✅ | ✅ | ✅ |
|
| [deepinfra](https://docs.litellm.ai/docs/providers/deepinfra) | ✅ | ✅ | ✅ | ✅ |
|
||||||
| [perplexity-ai](https://docs.litellm.ai/docs/providers/perplexity) | ✅ | ✅ | ✅ | ✅ |
|
| [perplexity-ai](https://docs.litellm.ai/docs/providers/perplexity) | ✅ | ✅ | ✅ | ✅ |
|
||||||
| [Groq AI](https://docs.litellm.ai/docs/providers/groq) | ✅ | ✅ | ✅ | ✅ |
|
| [Groq AI](https://docs.litellm.ai/docs/providers/groq) | ✅ | ✅ | ✅ | ✅ |
|
||||||
|
| [Deepseek](https://docs.litellm.ai/docs/providers/deepseek) | ✅ | ✅ | ✅ | ✅ |
|
||||||
| [anyscale](https://docs.litellm.ai/docs/providers/anyscale) | ✅ | ✅ | ✅ | ✅ |
|
| [anyscale](https://docs.litellm.ai/docs/providers/anyscale) | ✅ | ✅ | ✅ | ✅ |
|
||||||
| [IBM - watsonx.ai](https://docs.litellm.ai/docs/providers/watsonx) | ✅ | ✅ | ✅ | ✅ | ✅
|
| [IBM - watsonx.ai](https://docs.litellm.ai/docs/providers/watsonx) | ✅ | ✅ | ✅ | ✅ | ✅
|
||||||
| [voyage ai](https://docs.litellm.ai/docs/providers/voyage) | | | | | ✅ |
|
| [voyage ai](https://docs.litellm.ai/docs/providers/voyage) | | | | | ✅ |
|
||||||
|
|
187
cookbook/liteLLM_clarifai_Demo.ipynb
vendored
Normal file
187
cookbook/liteLLM_clarifai_Demo.ipynb
vendored
Normal file
|
@ -0,0 +1,187 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# LiteLLM Clarifai \n",
|
||||||
|
"This notebook walks you through on how to use liteLLM integration of Clarifai and call LLM model from clarifai with response in openAI output format."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Pre-Requisites"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"#install necessary packages\n",
|
||||||
|
"!pip install litellm\n",
|
||||||
|
"!pip install clarifai"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"To obtain Clarifai Personal Access Token follow the steps mentioned in the [link](https://docs.clarifai.com/clarifai-basics/authentication/personal-access-tokens/)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"## Set Clarifai Credentials\n",
|
||||||
|
"import os\n",
|
||||||
|
"os.environ[\"CLARIFAI_API_KEY\"]= \"YOUR_CLARIFAI_PAT\" # Clarifai PAT"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Mistral-large"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import litellm\n",
|
||||||
|
"\n",
|
||||||
|
"litellm.set_verbose=False"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Mistral large response : ModelResponse(id='chatcmpl-6eed494d-7ae2-4870-b9c2-6a64d50a6151', choices=[Choices(finish_reason='stop', index=1, message=Message(content=\"In the grand tapestry of time, where tales unfold,\\nLies the chronicle of ages, a sight to behold.\\nA tale of empires rising, and kings of old,\\nOf civilizations lost, and stories untold.\\n\\nOnce upon a yesterday, in a time so vast,\\nHumans took their first steps, casting shadows in the past.\\nFrom the cradle of mankind, a journey they embarked,\\nThrough stone and bronze and iron, their skills they sharpened and marked.\\n\\nEgyptians built pyramids, reaching for the skies,\\nWhile Greeks sought wisdom, truth, in philosophies that lie.\\nRoman legions marched, their empire to expand,\\nAnd in the East, the Silk Road joined the world, hand in hand.\\n\\nThe Middle Ages came, with knights in shining armor,\\nFeudal lords and serfs, a time of both clamor and calm order.\\nThen Renaissance bloomed, like a flower in the sun,\\nA rebirth of art and science, a new age had begun.\\n\\nAcross the vast oceans, explorers sailed with courage bold,\\nDiscovering new lands, stories of adventure, untold.\\nIndustrial Revolution churned, progress in its wake,\\nMachines and factories, a whole new world to make.\\n\\nTwo World Wars raged, a testament to man's strife,\\nYet from the ashes rose hope, a renewed will for life.\\nInto the modern era, technology took flight,\\nConnecting every corner, bathed in digital light.\\n\\nHistory, a symphony, a melody of time,\\nA testament to human will, resilience so sublime.\\nIn every page, a lesson, in every tale, a guide,\\nFor understanding our past, shapes our future's tide.\", role='assistant'))], created=1713896412, model='https://api.clarifai.com/v2/users/mistralai/apps/completion/models/mistral-large/outputs', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=13, completion_tokens=338, total_tokens=351))\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from litellm import completion\n",
|
||||||
|
"\n",
|
||||||
|
"messages = [{\"role\": \"user\",\"content\": \"\"\"Write a poem about history?\"\"\"}]\n",
|
||||||
|
"response=completion(\n",
|
||||||
|
" model=\"clarifai/mistralai.completion.mistral-large\",\n",
|
||||||
|
" messages=messages,\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"Mistral large response : {response}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Claude-2.1 "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Claude-2.1 response : ModelResponse(id='chatcmpl-d126c919-4db4-4aa3-ac8f-7edea41e0b93', choices=[Choices(finish_reason='stop', index=1, message=Message(content=\" Here's a poem I wrote about history:\\n\\nThe Tides of Time\\n\\nThe tides of time ebb and flow,\\nCarrying stories of long ago.\\nFigures and events come into light,\\nShaping the future with all their might.\\n\\nKingdoms rise, empires fall, \\nLeaving traces that echo down every hall.\\nRevolutions bring change with a fiery glow,\\nToppling structures from long ago.\\n\\nExplorers traverse each ocean and land,\\nSeeking treasures they don't understand.\\nWhile artists and writers try to make their mark,\\nHoping their works shine bright in the dark.\\n\\nThe cycle repeats again and again,\\nAs humanity struggles to learn from its pain.\\nThough the players may change on history's stage,\\nThe themes stay the same from age to age.\\n\\nWar and peace, life and death,\\nLove and strife with every breath.\\nThe tides of time continue their dance,\\nAs we join in, by luck or by chance.\\n\\nSo we study the past to light the way forward, \\nHeeding warnings from stories told and heard.\\nThe future unfolds from this unending flow -\\nWhere the tides of time ultimately go.\", role='assistant'))], created=1713896579, model='https://api.clarifai.com/v2/users/anthropic/apps/completion/models/claude-2_1/outputs', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=12, completion_tokens=232, total_tokens=244))\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from litellm import completion\n",
|
||||||
|
"\n",
|
||||||
|
"messages = [{\"role\": \"user\",\"content\": \"\"\"Write a poem about history?\"\"\"}]\n",
|
||||||
|
"response=completion(\n",
|
||||||
|
" model=\"clarifai/anthropic.completion.claude-2_1\",\n",
|
||||||
|
" messages=messages,\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"Claude-2.1 response : {response}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### OpenAI GPT-4 (Streaming)\n",
|
||||||
|
"Though clarifai doesn't support streaming, still you can call stream and get the response in standard StreamResponse format of liteLLM"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"ModelResponse(id='chatcmpl-40ae19af-3bf0-4eb4-99f2-33aec3ba84af', choices=[StreamingChoices(finish_reason=None, index=0, delta=Delta(content=\"In the quiet corners of time's grand hall,\\nLies the tale of rise and fall.\\nFrom ancient ruins to modern sprawl,\\nHistory, the greatest story of them all.\\n\\nEmpires have risen, empires have decayed,\\nThrough the eons, memories have stayed.\\nIn the book of time, history is laid,\\nA tapestry of events, meticulously displayed.\\n\\nThe pyramids of Egypt, standing tall,\\nThe Roman Empire's mighty sprawl.\\nFrom Alexander's conquest, to the Berlin Wall,\\nHistory, a silent witness to it all.\\n\\nIn the shadow of the past we tread,\\nWhere once kings and prophets led.\\nTheir stories in our hearts are spread,\\nEchoes of their words, in our minds are read.\\n\\nBattles fought and victories won,\\nActs of courage under the sun.\\nTales of love, of deeds done,\\nIn history's grand book, they all run.\\n\\nHeroes born, legends made,\\nIn the annals of time, they'll never fade.\\nTheir triumphs and failures all displayed,\\nIn the eternal march of history's parade.\\n\\nThe ink of the past is forever dry,\\nBut its lessons, we cannot deny.\\nIn its stories, truths lie,\\nIn its wisdom, we rely.\\n\\nHistory, a mirror to our past,\\nA guide for the future vast.\\nThrough its lens, we're ever cast,\\nIn the drama of life, forever vast.\", role='assistant', function_call=None, tool_calls=None), logprobs=None)], created=1714744515, model='https://api.clarifai.com/v2/users/openai/apps/chat-completion/models/GPT-4/outputs', object='chat.completion.chunk', system_fingerprint=None)\n",
|
||||||
|
"ModelResponse(id='chatcmpl-40ae19af-3bf0-4eb4-99f2-33aec3ba84af', choices=[StreamingChoices(finish_reason='stop', index=0, delta=Delta(content=None, role=None, function_call=None, tool_calls=None), logprobs=None)], created=1714744515, model='https://api.clarifai.com/v2/users/openai/apps/chat-completion/models/GPT-4/outputs', object='chat.completion.chunk', system_fingerprint=None)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from litellm import completion\n",
|
||||||
|
"\n",
|
||||||
|
"messages = [{\"role\": \"user\",\"content\": \"\"\"Write a poem about history?\"\"\"}]\n",
|
||||||
|
"response = completion(\n",
|
||||||
|
" model=\"clarifai/openai.chat-completion.GPT-4\",\n",
|
||||||
|
" messages=messages,\n",
|
||||||
|
" stream=True,\n",
|
||||||
|
" api_key = \"c75cc032415e45368be331fdd2c06db0\")\n",
|
||||||
|
"\n",
|
||||||
|
"for chunk in response:\n",
|
||||||
|
" print(chunk)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.13"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
BIN
deploy/azure_resource_manager/azure_marketplace.zip
Normal file
BIN
deploy/azure_resource_manager/azure_marketplace.zip
Normal file
Binary file not shown.
|
@ -0,0 +1,15 @@
|
||||||
|
{
|
||||||
|
"$schema": "https://schema.management.azure.com/schemas/0.1.2-preview/CreateUIDefinition.MultiVm.json#",
|
||||||
|
"handler": "Microsoft.Azure.CreateUIDef",
|
||||||
|
"version": "0.1.2-preview",
|
||||||
|
"parameters": {
|
||||||
|
"config": {
|
||||||
|
"isWizard": false,
|
||||||
|
"basics": { }
|
||||||
|
},
|
||||||
|
"basics": [ ],
|
||||||
|
"steps": [ ],
|
||||||
|
"outputs": { },
|
||||||
|
"resourceTypes": [ ]
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,63 @@
|
||||||
|
{
|
||||||
|
"$schema": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#",
|
||||||
|
"contentVersion": "1.0.0.0",
|
||||||
|
"parameters": {
|
||||||
|
"imageName": {
|
||||||
|
"type": "string",
|
||||||
|
"defaultValue": "ghcr.io/berriai/litellm:main-latest"
|
||||||
|
},
|
||||||
|
"containerName": {
|
||||||
|
"type": "string",
|
||||||
|
"defaultValue": "litellm-container"
|
||||||
|
},
|
||||||
|
"dnsLabelName": {
|
||||||
|
"type": "string",
|
||||||
|
"defaultValue": "litellm"
|
||||||
|
},
|
||||||
|
"portNumber": {
|
||||||
|
"type": "int",
|
||||||
|
"defaultValue": 4000
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"resources": [
|
||||||
|
{
|
||||||
|
"type": "Microsoft.ContainerInstance/containerGroups",
|
||||||
|
"apiVersion": "2021-03-01",
|
||||||
|
"name": "[parameters('containerName')]",
|
||||||
|
"location": "[resourceGroup().location]",
|
||||||
|
"properties": {
|
||||||
|
"containers": [
|
||||||
|
{
|
||||||
|
"name": "[parameters('containerName')]",
|
||||||
|
"properties": {
|
||||||
|
"image": "[parameters('imageName')]",
|
||||||
|
"resources": {
|
||||||
|
"requests": {
|
||||||
|
"cpu": 1,
|
||||||
|
"memoryInGB": 2
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ports": [
|
||||||
|
{
|
||||||
|
"port": "[parameters('portNumber')]"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"osType": "Linux",
|
||||||
|
"restartPolicy": "Always",
|
||||||
|
"ipAddress": {
|
||||||
|
"type": "Public",
|
||||||
|
"ports": [
|
||||||
|
{
|
||||||
|
"protocol": "tcp",
|
||||||
|
"port": "[parameters('portNumber')]"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"dnsNameLabel": "[parameters('dnsLabelName')]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
42
deploy/azure_resource_manager/main.bicep
Normal file
42
deploy/azure_resource_manager/main.bicep
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
param imageName string = 'ghcr.io/berriai/litellm:main-latest'
|
||||||
|
param containerName string = 'litellm-container'
|
||||||
|
param dnsLabelName string = 'litellm'
|
||||||
|
param portNumber int = 4000
|
||||||
|
|
||||||
|
resource containerGroupName 'Microsoft.ContainerInstance/containerGroups@2021-03-01' = {
|
||||||
|
name: containerName
|
||||||
|
location: resourceGroup().location
|
||||||
|
properties: {
|
||||||
|
containers: [
|
||||||
|
{
|
||||||
|
name: containerName
|
||||||
|
properties: {
|
||||||
|
image: imageName
|
||||||
|
resources: {
|
||||||
|
requests: {
|
||||||
|
cpu: 1
|
||||||
|
memoryInGB: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ports: [
|
||||||
|
{
|
||||||
|
port: portNumber
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
osType: 'Linux'
|
||||||
|
restartPolicy: 'Always'
|
||||||
|
ipAddress: {
|
||||||
|
type: 'Public'
|
||||||
|
ports: [
|
||||||
|
{
|
||||||
|
protocol: 'tcp'
|
||||||
|
port: portNumber
|
||||||
|
}
|
||||||
|
]
|
||||||
|
dnsNameLabel: dnsLabelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -24,7 +24,7 @@ version: 0.2.0
|
||||||
# incremented each time you make changes to the application. Versions are not expected to
|
# incremented each time you make changes to the application. Versions are not expected to
|
||||||
# follow Semantic Versioning. They should reflect the version the application is using.
|
# follow Semantic Versioning. They should reflect the version the application is using.
|
||||||
# It is recommended to use it with quotes.
|
# It is recommended to use it with quotes.
|
||||||
appVersion: v1.24.5
|
appVersion: v1.35.38
|
||||||
|
|
||||||
dependencies:
|
dependencies:
|
||||||
- name: "postgresql"
|
- name: "postgresql"
|
||||||
|
|
|
@ -4,6 +4,12 @@ LiteLLM allows you to:
|
||||||
* Send 1 completion call to many models: Return Fastest Response
|
* Send 1 completion call to many models: Return Fastest Response
|
||||||
* Send 1 completion call to many models: Return All Responses
|
* Send 1 completion call to many models: Return All Responses
|
||||||
|
|
||||||
|
:::info
|
||||||
|
|
||||||
|
Trying to do batch completion on LiteLLM Proxy ? Go here: https://docs.litellm.ai/docs/proxy/user_keys#beta-batch-completions---pass-model-as-list
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
## Send multiple completion calls to 1 model
|
## Send multiple completion calls to 1 model
|
||||||
|
|
||||||
In the batch_completion method, you provide a list of `messages` where each sub-list of messages is passed to `litellm.completion()`, allowing you to process multiple prompts efficiently in a single API call.
|
In the batch_completion method, you provide a list of `messages` where each sub-list of messages is passed to `litellm.completion()`, allowing you to process multiple prompts efficiently in a single API call.
|
||||||
|
|
|
@ -37,11 +37,12 @@ print(response) # ["max_tokens", "tools", "tool_choice", "stream"]
|
||||||
|
|
||||||
This is a list of openai params we translate across providers.
|
This is a list of openai params we translate across providers.
|
||||||
|
|
||||||
This list is constantly being updated.
|
Use `litellm.get_supported_openai_params()` for an updated list of params for each model + provider
|
||||||
|
|
||||||
| Provider | temperature | max_tokens | top_p | stream | stop | n | presence_penalty | frequency_penalty | functions | function_call | logit_bias | user | response_format | seed | tools | tool_choice | logprobs | top_logprobs | extra_headers |
|
| Provider | temperature | max_tokens | top_p | stream | stop | n | presence_penalty | frequency_penalty | functions | function_call | logit_bias | user | response_format | seed | tools | tool_choice | logprobs | top_logprobs | extra_headers |
|
||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|--|
|
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|--|
|
||||||
|Anthropic| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|
|Anthropic| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | ✅ | ✅ |
|
||||||
|
|Anthropic| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | ✅ | ✅ | ✅ | ✅ |
|
||||||
|OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | ✅ |
|
|OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
|Azure OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | | | ✅ |
|
|Azure OpenAI| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |✅ | ✅ | ✅ | ✅ |✅ | ✅ | | | ✅ |
|
||||||
|Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|
|Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|
||||||
|
@ -83,8 +84,9 @@ def completion(
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
n: Optional[int] = None,
|
n: Optional[int] = None,
|
||||||
stream: Optional[bool] = None,
|
stream: Optional[bool] = None,
|
||||||
|
stream_options: Optional[dict] = None,
|
||||||
stop=None,
|
stop=None,
|
||||||
max_tokens: Optional[float] = None,
|
max_tokens: Optional[int] = None,
|
||||||
presence_penalty: Optional[float] = None,
|
presence_penalty: Optional[float] = None,
|
||||||
frequency_penalty: Optional[float] = None,
|
frequency_penalty: Optional[float] = None,
|
||||||
logit_bias: Optional[dict] = None,
|
logit_bias: Optional[dict] = None,
|
||||||
|
@ -139,6 +141,10 @@ def completion(
|
||||||
|
|
||||||
- `stream`: *boolean or null (optional)* - If set to true, it sends partial message deltas. Tokens will be sent as they become available, with the stream terminated by a [DONE] message.
|
- `stream`: *boolean or null (optional)* - If set to true, it sends partial message deltas. Tokens will be sent as they become available, with the stream terminated by a [DONE] message.
|
||||||
|
|
||||||
|
- `stream_options` *dict or null (optional)* - Options for streaming response. Only set this when you set `stream: true`
|
||||||
|
|
||||||
|
- `include_usage` *boolean (optional)* - If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
|
||||||
|
|
||||||
- `stop`: *string/ array/ null (optional)* - Up to 4 sequences where the API will stop generating further tokens.
|
- `stop`: *string/ array/ null (optional)* - Up to 4 sequences where the API will stop generating further tokens.
|
||||||
|
|
||||||
- `max_tokens`: *integer (optional)* - The maximum number of tokens to generate in the chat completion.
|
- `max_tokens`: *integer (optional)* - The maximum number of tokens to generate in the chat completion.
|
||||||
|
|
|
@ -320,8 +320,6 @@ from litellm import embedding
|
||||||
litellm.vertex_project = "hardy-device-38811" # Your Project ID
|
litellm.vertex_project = "hardy-device-38811" # Your Project ID
|
||||||
litellm.vertex_location = "us-central1" # proj location
|
litellm.vertex_location = "us-central1" # proj location
|
||||||
|
|
||||||
|
|
||||||
os.environ['VOYAGE_API_KEY'] = ""
|
|
||||||
response = embedding(
|
response = embedding(
|
||||||
model="vertex_ai/textembedding-gecko",
|
model="vertex_ai/textembedding-gecko",
|
||||||
input=["good morning from litellm"],
|
input=["good morning from litellm"],
|
||||||
|
|
|
@ -17,6 +17,14 @@ This covers:
|
||||||
- ✅ [**JWT-Auth**](../docs/proxy/token_auth.md)
|
- ✅ [**JWT-Auth**](../docs/proxy/token_auth.md)
|
||||||
|
|
||||||
|
|
||||||
|
## [COMING SOON] AWS Marketplace Support
|
||||||
|
|
||||||
|
Deploy managed LiteLLM Proxy within your VPC.
|
||||||
|
|
||||||
|
Includes all enterprise features.
|
||||||
|
|
||||||
|
[**Get early access**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
||||||
|
|
||||||
## Frequently Asked Questions
|
## Frequently Asked Questions
|
||||||
|
|
||||||
### What topics does Professional support cover and what SLAs do you offer?
|
### What topics does Professional support cover and what SLAs do you offer?
|
||||||
|
|
|
@ -13,7 +13,7 @@ LiteLLM maps exceptions across all providers to their OpenAI counterparts.
|
||||||
| >=500 | InternalServerError |
|
| >=500 | InternalServerError |
|
||||||
| N/A | ContextWindowExceededError|
|
| N/A | ContextWindowExceededError|
|
||||||
| 400 | ContentPolicyViolationError|
|
| 400 | ContentPolicyViolationError|
|
||||||
| N/A | APIConnectionError |
|
| 500 | APIConnectionError |
|
||||||
|
|
||||||
|
|
||||||
Base case we return APIConnectionError
|
Base case we return APIConnectionError
|
||||||
|
@ -74,6 +74,28 @@ except Exception as e:
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Usage - Should you retry exception?
|
||||||
|
|
||||||
|
```
|
||||||
|
import litellm
|
||||||
|
import openai
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = litellm.completion(
|
||||||
|
model="gpt-4",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "hello, write a 20 pageg essay"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
timeout=0.01, # this will raise a timeout exception
|
||||||
|
)
|
||||||
|
except openai.APITimeoutError as e:
|
||||||
|
should_retry = litellm._should_retry(e.status_code)
|
||||||
|
print(f"should_retry: {should_retry}")
|
||||||
|
```
|
||||||
|
|
||||||
## Details
|
## Details
|
||||||
|
|
||||||
To see how it's implemented - [check out the code](https://github.com/BerriAI/litellm/blob/a42c197e5a6de56ea576c73715e6c7c6b19fa249/litellm/utils.py#L1217)
|
To see how it's implemented - [check out the code](https://github.com/BerriAI/litellm/blob/a42c197e5a6de56ea576c73715e6c7c6b19fa249/litellm/utils.py#L1217)
|
||||||
|
@ -84,23 +106,37 @@ To see how it's implemented - [check out the code](https://github.com/BerriAI/li
|
||||||
|
|
||||||
## Custom mapping list
|
## Custom mapping list
|
||||||
|
|
||||||
Base case - we return the original exception.
|
Base case - we return `litellm.APIConnectionError` exception (inherits from openai's APIConnectionError exception).
|
||||||
|
|
||||||
| | ContextWindowExceededError | AuthenticationError | InvalidRequestError | RateLimitError | ServiceUnavailableError |
|
| custom_llm_provider | Timeout | ContextWindowExceededError | BadRequestError | NotFoundError | ContentPolicyViolationError | AuthenticationError | APIError | RateLimitError | ServiceUnavailableError | PermissionDeniedError | UnprocessableEntityError |
|
||||||
|---------------|----------------------------|---------------------|---------------------|---------------|-------------------------|
|
|----------------------------|---------|----------------------------|------------------|---------------|-----------------------------|---------------------|----------|----------------|-------------------------|-----------------------|-------------------------|
|
||||||
| Anthropic | ✅ | ✅ | ✅ | ✅ | |
|
| openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
|
||||||
| OpenAI | ✅ | ✅ |✅ |✅ |✅|
|
| watsonx | | | | | | | |✓| | | |
|
||||||
| Azure OpenAI | ✅ | ✅ |✅ |✅ |✅|
|
| text-completion-openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
|
||||||
| Replicate | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| custom_openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
|
||||||
| Cohere | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| openai_compatible_providers| ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
|
||||||
| Huggingface | ✅ | ✅ | ✅ | ✅ | |
|
| anthropic | ✓ | ✓ | ✓ | ✓ | | ✓ | | | ✓ | ✓ | |
|
||||||
| Openrouter | ✅ | ✅ | ✅ | ✅ | |
|
| replicate | ✓ | ✓ | ✓ | ✓ | | ✓ | | ✓ | ✓ | | |
|
||||||
| AI21 | ✅ | ✅ | ✅ | ✅ | |
|
| bedrock | ✓ | ✓ | ✓ | ✓ | | ✓ | | ✓ | ✓ | ✓ | |
|
||||||
| VertexAI | | |✅ | | |
|
| sagemaker | | ✓ | ✓ | | | | | | | | |
|
||||||
| Bedrock | | |✅ | | |
|
| vertex_ai | ✓ | | ✓ | | | | ✓ | | | | ✓ |
|
||||||
| Sagemaker | | |✅ | | |
|
| palm | ✓ | ✓ | | | | | ✓ | | | | |
|
||||||
| TogetherAI | ✅ | ✅ | ✅ | ✅ | |
|
| gemini | ✓ | ✓ | | | | | ✓ | | | | |
|
||||||
| AlephAlpha | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| cloudflare | | | ✓ | | | ✓ | | | | | |
|
||||||
|
| cohere | | ✓ | ✓ | | | ✓ | | | ✓ | | |
|
||||||
|
| cohere_chat | | ✓ | ✓ | | | ✓ | | | ✓ | | |
|
||||||
|
| huggingface | ✓ | ✓ | ✓ | | | ✓ | | ✓ | ✓ | | |
|
||||||
|
| ai21 | ✓ | ✓ | ✓ | ✓ | | ✓ | | ✓ | | | |
|
||||||
|
| nlp_cloud | ✓ | ✓ | ✓ | | | ✓ | ✓ | ✓ | ✓ | | |
|
||||||
|
| together_ai | ✓ | ✓ | ✓ | | | ✓ | | | | | |
|
||||||
|
| aleph_alpha | | | ✓ | | | ✓ | | | | | |
|
||||||
|
| ollama | ✓ | | ✓ | | | | | | ✓ | | |
|
||||||
|
| ollama_chat | ✓ | | ✓ | | | | | | ✓ | | |
|
||||||
|
| vllm | | | | | | ✓ | ✓ | | | | |
|
||||||
|
| azure | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | | ✓ | | |
|
||||||
|
|
||||||
|
- "✓" indicates that the specified `custom_llm_provider` can raise the corresponding exception.
|
||||||
|
- Empty cells indicate the lack of association or that the provider does not raise that particular exception type as indicated by the function.
|
||||||
|
|
||||||
|
|
||||||
> For a deeper understanding of these exceptions, you can check out [this](https://github.com/BerriAI/litellm/blob/d7e58d13bf9ba9edbab2ab2f096f3de7547f35fa/litellm/utils.py#L1544) implementation for additional insights.
|
> For a deeper understanding of these exceptions, you can check out [this](https://github.com/BerriAI/litellm/blob/d7e58d13bf9ba9edbab2ab2f096f3de7547f35fa/litellm/utils.py#L1544) implementation for additional insights.
|
||||||
|
|
|
@ -47,3 +47,12 @@ Pricing is based on usage. We can figure out a price that works for your team, o
|
||||||
<Image img={require('../img/litellm_hosted_ui_router.png')} />
|
<Image img={require('../img/litellm_hosted_ui_router.png')} />
|
||||||
|
|
||||||
#### [**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
#### [**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
||||||
|
|
||||||
|
## Feature List
|
||||||
|
|
||||||
|
- Easy way to add/remove models
|
||||||
|
- 100% uptime even when models are added/removed
|
||||||
|
- custom callback webhooks
|
||||||
|
- your domain name with HTTPS
|
||||||
|
- Ability to create/delete User API keys
|
||||||
|
- Reasonable set monthly cost
|
|
@ -14,14 +14,14 @@ import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import os
|
import os
|
||||||
from langchain.chat_models import ChatLiteLLM
|
from langchain_community.chat_models import ChatLiteLLM
|
||||||
from langchain.prompts.chat import (
|
from langchain_core.prompts import (
|
||||||
ChatPromptTemplate,
|
ChatPromptTemplate,
|
||||||
SystemMessagePromptTemplate,
|
SystemMessagePromptTemplate,
|
||||||
AIMessagePromptTemplate,
|
AIMessagePromptTemplate,
|
||||||
HumanMessagePromptTemplate,
|
HumanMessagePromptTemplate,
|
||||||
)
|
)
|
||||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
os.environ['OPENAI_API_KEY'] = ""
|
os.environ['OPENAI_API_KEY'] = ""
|
||||||
chat = ChatLiteLLM(model="gpt-3.5-turbo")
|
chat = ChatLiteLLM(model="gpt-3.5-turbo")
|
||||||
|
@ -30,7 +30,7 @@ messages = [
|
||||||
content="what model are you"
|
content="what model are you"
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
chat(messages)
|
chat.invoke(messages)
|
||||||
```
|
```
|
||||||
|
|
||||||
</TabItem>
|
</TabItem>
|
||||||
|
@ -39,14 +39,14 @@ chat(messages)
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import os
|
import os
|
||||||
from langchain.chat_models import ChatLiteLLM
|
from langchain_community.chat_models import ChatLiteLLM
|
||||||
from langchain.prompts.chat import (
|
from langchain_core.prompts import (
|
||||||
ChatPromptTemplate,
|
ChatPromptTemplate,
|
||||||
SystemMessagePromptTemplate,
|
SystemMessagePromptTemplate,
|
||||||
AIMessagePromptTemplate,
|
AIMessagePromptTemplate,
|
||||||
HumanMessagePromptTemplate,
|
HumanMessagePromptTemplate,
|
||||||
)
|
)
|
||||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
os.environ['ANTHROPIC_API_KEY'] = ""
|
os.environ['ANTHROPIC_API_KEY'] = ""
|
||||||
chat = ChatLiteLLM(model="claude-2", temperature=0.3)
|
chat = ChatLiteLLM(model="claude-2", temperature=0.3)
|
||||||
|
@ -55,7 +55,7 @@ messages = [
|
||||||
content="what model are you"
|
content="what model are you"
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
chat(messages)
|
chat.invoke(messages)
|
||||||
```
|
```
|
||||||
|
|
||||||
</TabItem>
|
</TabItem>
|
||||||
|
@ -64,14 +64,14 @@ chat(messages)
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import os
|
import os
|
||||||
from langchain.chat_models import ChatLiteLLM
|
from langchain_community.chat_models import ChatLiteLLM
|
||||||
from langchain.prompts.chat import (
|
from langchain_core.prompts.chat import (
|
||||||
ChatPromptTemplate,
|
ChatPromptTemplate,
|
||||||
SystemMessagePromptTemplate,
|
SystemMessagePromptTemplate,
|
||||||
AIMessagePromptTemplate,
|
AIMessagePromptTemplate,
|
||||||
HumanMessagePromptTemplate,
|
HumanMessagePromptTemplate,
|
||||||
)
|
)
|
||||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
os.environ['REPLICATE_API_TOKEN'] = ""
|
os.environ['REPLICATE_API_TOKEN'] = ""
|
||||||
chat = ChatLiteLLM(model="replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1")
|
chat = ChatLiteLLM(model="replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1")
|
||||||
|
@ -80,7 +80,7 @@ messages = [
|
||||||
content="what model are you?"
|
content="what model are you?"
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
chat(messages)
|
chat.invoke(messages)
|
||||||
```
|
```
|
||||||
|
|
||||||
</TabItem>
|
</TabItem>
|
||||||
|
@ -89,14 +89,14 @@ chat(messages)
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import os
|
import os
|
||||||
from langchain.chat_models import ChatLiteLLM
|
from langchain_community.chat_models import ChatLiteLLM
|
||||||
from langchain.prompts.chat import (
|
from langchain_core.prompts import (
|
||||||
ChatPromptTemplate,
|
ChatPromptTemplate,
|
||||||
SystemMessagePromptTemplate,
|
SystemMessagePromptTemplate,
|
||||||
AIMessagePromptTemplate,
|
AIMessagePromptTemplate,
|
||||||
HumanMessagePromptTemplate,
|
HumanMessagePromptTemplate,
|
||||||
)
|
)
|
||||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
os.environ['COHERE_API_KEY'] = ""
|
os.environ['COHERE_API_KEY'] = ""
|
||||||
chat = ChatLiteLLM(model="command-nightly")
|
chat = ChatLiteLLM(model="command-nightly")
|
||||||
|
@ -105,32 +105,9 @@ messages = [
|
||||||
content="what model are you?"
|
content="what model are you?"
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
chat(messages)
|
chat.invoke(messages)
|
||||||
```
|
```
|
||||||
|
|
||||||
</TabItem>
|
|
||||||
<TabItem value="palm" label="PaLM - Google">
|
|
||||||
|
|
||||||
```python
|
|
||||||
import os
|
|
||||||
from langchain.chat_models import ChatLiteLLM
|
|
||||||
from langchain.prompts.chat import (
|
|
||||||
ChatPromptTemplate,
|
|
||||||
SystemMessagePromptTemplate,
|
|
||||||
AIMessagePromptTemplate,
|
|
||||||
HumanMessagePromptTemplate,
|
|
||||||
)
|
|
||||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
|
||||||
|
|
||||||
os.environ['PALM_API_KEY'] = ""
|
|
||||||
chat = ChatLiteLLM(model="palm/chat-bison")
|
|
||||||
messages = [
|
|
||||||
HumanMessage(
|
|
||||||
content="what model are you?"
|
|
||||||
)
|
|
||||||
]
|
|
||||||
chat(messages)
|
|
||||||
```
|
|
||||||
</TabItem>
|
</TabItem>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
|
|
|
@ -94,9 +94,10 @@ print(response)
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Set Custom Trace ID, Trace User ID and Tags
|
### Set Custom Trace ID, Trace User ID, Trace Metadata, Trace Version, Trace Release and Tags
|
||||||
|
|
||||||
|
Pass `trace_id`, `trace_user_id`, `trace_metadata`, `trace_version`, `trace_release`, `tags` in `metadata`
|
||||||
|
|
||||||
Pass `trace_id`, `trace_user_id` in `metadata`
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import litellm
|
import litellm
|
||||||
|
@ -121,12 +122,21 @@ response = completion(
|
||||||
metadata={
|
metadata={
|
||||||
"generation_name": "ishaan-test-generation", # set langfuse Generation Name
|
"generation_name": "ishaan-test-generation", # set langfuse Generation Name
|
||||||
"generation_id": "gen-id22", # set langfuse Generation ID
|
"generation_id": "gen-id22", # set langfuse Generation ID
|
||||||
|
"version": "test-generation-version" # set langfuse Generation Version
|
||||||
"trace_user_id": "user-id2", # set langfuse Trace User ID
|
"trace_user_id": "user-id2", # set langfuse Trace User ID
|
||||||
"session_id": "session-1", # set langfuse Session ID
|
"session_id": "session-1", # set langfuse Session ID
|
||||||
"tags": ["tag1", "tag2"] # set langfuse Tags
|
"tags": ["tag1", "tag2"], # set langfuse Tags
|
||||||
"trace_id": "trace-id22", # set langfuse Trace ID
|
"trace_id": "trace-id22", # set langfuse Trace ID
|
||||||
|
"trace_metadata": {"key": "value"}, # set langfuse Trace Metadata
|
||||||
|
"trace_version": "test-trace-version", # set langfuse Trace Version (if not set, defaults to Generation Version)
|
||||||
|
"trace_release": "test-trace-release", # set langfuse Trace Release
|
||||||
### OR ###
|
### OR ###
|
||||||
"existing_trace_id": "trace-id22", # if generation is continuation of past trace. This prevents default behaviour of setting a trace name
|
"existing_trace_id": "trace-id22", # if generation is continuation of past trace. This prevents default behaviour of setting a trace name
|
||||||
|
### OR enforce that certain fields are trace overwritten in the trace during the continuation ###
|
||||||
|
"existing_trace_id": "trace-id22",
|
||||||
|
"trace_metadata": {"key": "updated_trace_value"}, # The new value to use for the langfuse Trace Metadata
|
||||||
|
"update_trace_keys": ["input", "output", "trace_metadata"], # Updates the trace input & output to be this generations input & output also updates the Trace Metadata to match the passed in value
|
||||||
|
"debug_langfuse": True, # Will log the exact metadata sent to litellm for the trace/generation as `metadata_passed_to_litellm`
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -134,6 +144,38 @@ print(response)
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Trace & Generation Parameters
|
||||||
|
|
||||||
|
#### Trace Specific Parameters
|
||||||
|
|
||||||
|
* `trace_id` - Identifier for the trace, must use `existing_trace_id` instead or in conjunction with `trace_id` if this is an existing trace, auto-generated by default
|
||||||
|
* `trace_name` - Name of the trace, auto-generated by default
|
||||||
|
* `session_id` - Session identifier for the trace, defaults to `None`
|
||||||
|
* `trace_version` - Version for the trace, defaults to value for `version`
|
||||||
|
* `trace_release` - Release for the trace, defaults to `None`
|
||||||
|
* `trace_metadata` - Metadata for the trace, defaults to `None`
|
||||||
|
* `trace_user_id` - User identifier for the trace, defaults to completion argument `user`
|
||||||
|
* `tags` - Tags for the trace, defeaults to `None`
|
||||||
|
|
||||||
|
##### Updatable Parameters on Continuation
|
||||||
|
|
||||||
|
The following parameters can be updated on a continuation of a trace by passing in the following values into the `update_trace_keys` in the metadata of the completion.
|
||||||
|
|
||||||
|
* `input` - Will set the traces input to be the input of this latest generation
|
||||||
|
* `output` - Will set the traces output to be the output of this generation
|
||||||
|
* `trace_version` - Will set the trace version to be the provided value (To use the latest generations version instead, use `version`)
|
||||||
|
* `trace_release` - Will set the trace release to be the provided value
|
||||||
|
* `trace_metadata` - Will set the trace metadata to the provided value
|
||||||
|
* `trace_user_id` - Will set the trace user id to the provided value
|
||||||
|
|
||||||
|
#### Generation Specific Parameters
|
||||||
|
|
||||||
|
* `generation_id` - Identifier for the generation, auto-generated by default
|
||||||
|
* `generation_name` - Identifier for the generation, auto-generated by default
|
||||||
|
* `prompt` - Langfuse prompt object used for the generation, defaults to None
|
||||||
|
|
||||||
|
Any other key value pairs passed into the metadata not listed in the above spec for a `litellm` completion will be added as a metadata key value pair for the generation.
|
||||||
|
|
||||||
### Use LangChain ChatLiteLLM + Langfuse
|
### Use LangChain ChatLiteLLM + Langfuse
|
||||||
Pass `trace_user_id`, `session_id` in model_kwargs
|
Pass `trace_user_id`, `session_id` in model_kwargs
|
||||||
```python
|
```python
|
||||||
|
@ -171,8 +213,20 @@ chat(messages)
|
||||||
|
|
||||||
## Redacting Messages, Response Content from Langfuse Logging
|
## Redacting Messages, Response Content from Langfuse Logging
|
||||||
|
|
||||||
|
### Redact Messages and Responses from all Langfuse Logging
|
||||||
|
|
||||||
Set `litellm.turn_off_message_logging=True` This will prevent the messages and responses from being logged to langfuse, but request metadata will still be logged.
|
Set `litellm.turn_off_message_logging=True` This will prevent the messages and responses from being logged to langfuse, but request metadata will still be logged.
|
||||||
|
|
||||||
|
### Redact Messages and Responses from specific Langfuse Logging
|
||||||
|
|
||||||
|
In the metadata typically passed for text completion or embedding calls you can set specific keys to mask the messages and responses for this call.
|
||||||
|
|
||||||
|
Setting `mask_input` to `True` will mask the input from being logged for this call
|
||||||
|
|
||||||
|
Setting `mask_output` to `True` will make the output from being logged for this call.
|
||||||
|
|
||||||
|
Be aware that if you are continuing an existing trace, and you set `update_trace_keys` to include either `input` or `output` and you set the corresponding `mask_input` or `mask_output`, then that trace will have its existing input and/or output replaced with a redacted message.
|
||||||
|
|
||||||
## Troubleshooting & Errors
|
## Troubleshooting & Errors
|
||||||
### Data not getting logged to Langfuse ?
|
### Data not getting logged to Langfuse ?
|
||||||
- Ensure you're on the latest version of langfuse `pip install langfuse -U`. The latest version allows litellm to log JSON input/outputs to langfuse
|
- Ensure you're on the latest version of langfuse `pip install langfuse -U`. The latest version allows litellm to log JSON input/outputs to langfuse
|
||||||
|
|
|
@ -535,7 +535,8 @@ print(response)
|
||||||
|
|
||||||
| Model Name | Function Call |
|
| Model Name | Function Call |
|
||||||
|----------------------|---------------------------------------------|
|
|----------------------|---------------------------------------------|
|
||||||
| Titan Embeddings - G1 | `embedding(model="bedrock/amazon.titan-embed-text-v1", input=input)` |
|
| Titan Embeddings V2 | `embedding(model="bedrock/amazon.titan-embed-text-v2:0", input=input)` |
|
||||||
|
| Titan Embeddings - V1 | `embedding(model="bedrock/amazon.titan-embed-text-v1", input=input)` |
|
||||||
| Cohere Embeddings - English | `embedding(model="bedrock/cohere.embed-english-v3", input=input)` |
|
| Cohere Embeddings - English | `embedding(model="bedrock/cohere.embed-english-v3", input=input)` |
|
||||||
| Cohere Embeddings - Multilingual | `embedding(model="bedrock/cohere.embed-multilingual-v3", input=input)` |
|
| Cohere Embeddings - Multilingual | `embedding(model="bedrock/cohere.embed-multilingual-v3", input=input)` |
|
||||||
|
|
||||||
|
|
177
docs/my-website/docs/providers/clarifai.md
Normal file
177
docs/my-website/docs/providers/clarifai.md
Normal file
|
@ -0,0 +1,177 @@
|
||||||
|
|
||||||
|
# Clarifai
|
||||||
|
Anthropic, OpenAI, Mistral, Llama and Gemini LLMs are Supported on Clarifai.
|
||||||
|
|
||||||
|
## Pre-Requisites
|
||||||
|
|
||||||
|
`pip install clarifai`
|
||||||
|
|
||||||
|
`pip install litellm`
|
||||||
|
|
||||||
|
## Required Environment Variables
|
||||||
|
To obtain your Clarifai Personal access token follow this [link](https://docs.clarifai.com/clarifai-basics/authentication/personal-access-tokens/). Optionally the PAT can also be passed in `completion` function.
|
||||||
|
|
||||||
|
```python
|
||||||
|
os.environ["CALRIFAI_API_KEY"] = "YOUR_CLARIFAI_PAT" # CLARIFAI_PAT
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from litellm import completion
|
||||||
|
|
||||||
|
os.environ["CLARIFAI_API_KEY"] = ""
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="clarifai/mistralai.completion.mistral-large",
|
||||||
|
messages=[{ "content": "Tell me a joke about physics?","role": "user"}]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Output**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "chatcmpl-572701ee-9ab2-411c-ac75-46c1ba18e781",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"index": 1,
|
||||||
|
"message": {
|
||||||
|
"content": "Sure, here's a physics joke for you:\n\nWhy can't you trust an atom?\n\nBecause they make up everything!",
|
||||||
|
"role": "assistant"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1714410197,
|
||||||
|
"model": "https://api.clarifai.com/v2/users/mistralai/apps/completion/models/mistral-large/outputs",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": null,
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 14,
|
||||||
|
"completion_tokens": 24,
|
||||||
|
"total_tokens": 38
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Clarifai models
|
||||||
|
liteLLM supports non-streaming requests to all models on [Clarifai community](https://clarifai.com/explore/models?filterData=%5B%7B%22field%22%3A%22use_cases%22%2C%22value%22%3A%5B%22llm%22%5D%7D%5D&page=1&perPage=24)
|
||||||
|
|
||||||
|
Example Usage - Note: liteLLM supports all models deployed on Clarifai
|
||||||
|
|
||||||
|
## Llama LLMs
|
||||||
|
| Model Name | Function Call |
|
||||||
|
---------------------------|---------------------------------|
|
||||||
|
| clarifai/meta.Llama-2.llama2-7b-chat | `completion('clarifai/meta.Llama-2.llama2-7b-chat', messages)`
|
||||||
|
| clarifai/meta.Llama-2.llama2-13b-chat | `completion('clarifai/meta.Llama-2.llama2-13b-chat', messages)`
|
||||||
|
| clarifai/meta.Llama-2.llama2-70b-chat | `completion('clarifai/meta.Llama-2.llama2-70b-chat', messages)` |
|
||||||
|
| clarifai/meta.Llama-2.codeLlama-70b-Python | `completion('clarifai/meta.Llama-2.codeLlama-70b-Python', messages)`|
|
||||||
|
| clarifai/meta.Llama-2.codeLlama-70b-Instruct | `completion('clarifai/meta.Llama-2.codeLlama-70b-Instruct', messages)` |
|
||||||
|
|
||||||
|
## Mistal LLMs
|
||||||
|
| Model Name | Function Call |
|
||||||
|
|---------------------------------------------|------------------------------------------------------------------------|
|
||||||
|
| clarifai/mistralai.completion.mixtral-8x22B | `completion('clarifai/mistralai.completion.mixtral-8x22B', messages)` |
|
||||||
|
| clarifai/mistralai.completion.mistral-large | `completion('clarifai/mistralai.completion.mistral-large', messages)` |
|
||||||
|
| clarifai/mistralai.completion.mistral-medium | `completion('clarifai/mistralai.completion.mistral-medium', messages)` |
|
||||||
|
| clarifai/mistralai.completion.mistral-small | `completion('clarifai/mistralai.completion.mistral-small', messages)` |
|
||||||
|
| clarifai/mistralai.completion.mixtral-8x7B-Instruct-v0_1 | `completion('clarifai/mistralai.completion.mixtral-8x7B-Instruct-v0_1', messages)`
|
||||||
|
| clarifai/mistralai.completion.mistral-7B-OpenOrca | `completion('clarifai/mistralai.completion.mistral-7B-OpenOrca', messages)` |
|
||||||
|
| clarifai/mistralai.completion.openHermes-2-mistral-7B | `completion('clarifai/mistralai.completion.openHermes-2-mistral-7B', messages)` |
|
||||||
|
|
||||||
|
|
||||||
|
## Jurassic LLMs
|
||||||
|
| Model Name | Function Call |
|
||||||
|
|-----------------------------------------------|---------------------------------------------------------------------|
|
||||||
|
| clarifai/ai21.complete.Jurassic2-Grande | `completion('clarifai/ai21.complete.Jurassic2-Grande', messages)` |
|
||||||
|
| clarifai/ai21.complete.Jurassic2-Grande-Instruct | `completion('clarifai/ai21.complete.Jurassic2-Grande-Instruct', messages)` |
|
||||||
|
| clarifai/ai21.complete.Jurassic2-Jumbo-Instruct | `completion('clarifai/ai21.complete.Jurassic2-Jumbo-Instruct', messages)` |
|
||||||
|
| clarifai/ai21.complete.Jurassic2-Jumbo | `completion('clarifai/ai21.complete.Jurassic2-Jumbo', messages)` |
|
||||||
|
| clarifai/ai21.complete.Jurassic2-Large | `completion('clarifai/ai21.complete.Jurassic2-Large', messages)` |
|
||||||
|
|
||||||
|
## Wizard LLMs
|
||||||
|
|
||||||
|
| Model Name | Function Call |
|
||||||
|
|-----------------------------------------------|---------------------------------------------------------------------|
|
||||||
|
| clarifai/wizardlm.generate.wizardCoder-Python-34B | `completion('clarifai/wizardlm.generate.wizardCoder-Python-34B', messages)` |
|
||||||
|
| clarifai/wizardlm.generate.wizardLM-70B | `completion('clarifai/wizardlm.generate.wizardLM-70B', messages)` |
|
||||||
|
| clarifai/wizardlm.generate.wizardLM-13B | `completion('clarifai/wizardlm.generate.wizardLM-13B', messages)` |
|
||||||
|
| clarifai/wizardlm.generate.wizardCoder-15B | `completion('clarifai/wizardlm.generate.wizardCoder-15B', messages)` |
|
||||||
|
|
||||||
|
## Anthropic models
|
||||||
|
|
||||||
|
| Model Name | Function Call |
|
||||||
|
|-----------------------------------------------|---------------------------------------------------------------------|
|
||||||
|
| clarifai/anthropic.completion.claude-v1 | `completion('clarifai/anthropic.completion.claude-v1', messages)` |
|
||||||
|
| clarifai/anthropic.completion.claude-instant-1_2 | `completion('clarifai/anthropic.completion.claude-instant-1_2', messages)` |
|
||||||
|
| clarifai/anthropic.completion.claude-instant | `completion('clarifai/anthropic.completion.claude-instant', messages)` |
|
||||||
|
| clarifai/anthropic.completion.claude-v2 | `completion('clarifai/anthropic.completion.claude-v2', messages)` |
|
||||||
|
| clarifai/anthropic.completion.claude-2_1 | `completion('clarifai/anthropic.completion.claude-2_1', messages)` |
|
||||||
|
| clarifai/anthropic.completion.claude-3-opus | `completion('clarifai/anthropic.completion.claude-3-opus', messages)` |
|
||||||
|
| clarifai/anthropic.completion.claude-3-sonnet | `completion('clarifai/anthropic.completion.claude-3-sonnet', messages)` |
|
||||||
|
|
||||||
|
## OpenAI GPT LLMs
|
||||||
|
|
||||||
|
| Model Name | Function Call |
|
||||||
|
|-----------------------------------------------|---------------------------------------------------------------------|
|
||||||
|
| clarifai/openai.chat-completion.GPT-4 | `completion('clarifai/openai.chat-completion.GPT-4', messages)` |
|
||||||
|
| clarifai/openai.chat-completion.GPT-3_5-turbo | `completion('clarifai/openai.chat-completion.GPT-3_5-turbo', messages)` |
|
||||||
|
| clarifai/openai.chat-completion.gpt-4-turbo | `completion('clarifai/openai.chat-completion.gpt-4-turbo', messages)` |
|
||||||
|
| clarifai/openai.completion.gpt-3_5-turbo-instruct | `completion('clarifai/openai.completion.gpt-3_5-turbo-instruct', messages)` |
|
||||||
|
|
||||||
|
## GCP LLMs
|
||||||
|
|
||||||
|
| Model Name | Function Call |
|
||||||
|
|-----------------------------------------------|---------------------------------------------------------------------|
|
||||||
|
| clarifai/gcp.generate.gemini-1_5-pro | `completion('clarifai/gcp.generate.gemini-1_5-pro', messages)` |
|
||||||
|
| clarifai/gcp.generate.imagen-2 | `completion('clarifai/gcp.generate.imagen-2', messages)` |
|
||||||
|
| clarifai/gcp.generate.code-gecko | `completion('clarifai/gcp.generate.code-gecko', messages)` |
|
||||||
|
| clarifai/gcp.generate.code-bison | `completion('clarifai/gcp.generate.code-bison', messages)` |
|
||||||
|
| clarifai/gcp.generate.text-bison | `completion('clarifai/gcp.generate.text-bison', messages)` |
|
||||||
|
| clarifai/gcp.generate.gemma-2b-it | `completion('clarifai/gcp.generate.gemma-2b-it', messages)` |
|
||||||
|
| clarifai/gcp.generate.gemma-7b-it | `completion('clarifai/gcp.generate.gemma-7b-it', messages)` |
|
||||||
|
| clarifai/gcp.generate.gemini-pro | `completion('clarifai/gcp.generate.gemini-pro', messages)` |
|
||||||
|
| clarifai/gcp.generate.gemma-1_1-7b-it | `completion('clarifai/gcp.generate.gemma-1_1-7b-it', messages)` |
|
||||||
|
|
||||||
|
## Cohere LLMs
|
||||||
|
| Model Name | Function Call |
|
||||||
|
|-----------------------------------------------|---------------------------------------------------------------------|
|
||||||
|
| clarifai/cohere.generate.cohere-generate-command | `completion('clarifai/cohere.generate.cohere-generate-command', messages)` |
|
||||||
|
clarifai/cohere.generate.command-r-plus' | `completion('clarifai/clarifai/cohere.generate.command-r-plus', messages)`|
|
||||||
|
|
||||||
|
## Databricks LLMs
|
||||||
|
|
||||||
|
| Model Name | Function Call |
|
||||||
|
|---------------------------------------------------|---------------------------------------------------------------------|
|
||||||
|
| clarifai/databricks.drbx.dbrx-instruct | `completion('clarifai/databricks.drbx.dbrx-instruct', messages)` |
|
||||||
|
| clarifai/databricks.Dolly-v2.dolly-v2-12b | `completion('clarifai/databricks.Dolly-v2.dolly-v2-12b', messages)`|
|
||||||
|
|
||||||
|
## Microsoft LLMs
|
||||||
|
|
||||||
|
| Model Name | Function Call |
|
||||||
|
|---------------------------------------------------|---------------------------------------------------------------------|
|
||||||
|
| clarifai/microsoft.text-generation.phi-2 | `completion('clarifai/microsoft.text-generation.phi-2', messages)` |
|
||||||
|
| clarifai/microsoft.text-generation.phi-1_5 | `completion('clarifai/microsoft.text-generation.phi-1_5', messages)`|
|
||||||
|
|
||||||
|
## Salesforce models
|
||||||
|
|
||||||
|
| Model Name | Function Call |
|
||||||
|
|-----------------------------------------------------------|-------------------------------------------------------------------------------|
|
||||||
|
| clarifai/salesforce.blip.general-english-image-caption-blip-2 | `completion('clarifai/salesforce.blip.general-english-image-caption-blip-2', messages)` |
|
||||||
|
| clarifai/salesforce.xgen.xgen-7b-8k-instruct | `completion('clarifai/salesforce.xgen.xgen-7b-8k-instruct', messages)` |
|
||||||
|
|
||||||
|
|
||||||
|
## Other Top performing LLMs
|
||||||
|
|
||||||
|
| Model Name | Function Call |
|
||||||
|
|---------------------------------------------------|---------------------------------------------------------------------|
|
||||||
|
| clarifai/deci.decilm.deciLM-7B-instruct | `completion('clarifai/deci.decilm.deciLM-7B-instruct', messages)` |
|
||||||
|
| clarifai/upstage.solar.solar-10_7b-instruct | `completion('clarifai/upstage.solar.solar-10_7b-instruct', messages)` |
|
||||||
|
| clarifai/openchat.openchat.openchat-3_5-1210 | `completion('clarifai/openchat.openchat.openchat-3_5-1210', messages)` |
|
||||||
|
| clarifai/togethercomputer.stripedHyena.stripedHyena-Nous-7B | `completion('clarifai/togethercomputer.stripedHyena.stripedHyena-Nous-7B', messages)` |
|
||||||
|
| clarifai/fblgit.una-cybertron.una-cybertron-7b-v2 | `completion('clarifai/fblgit.una-cybertron.una-cybertron-7b-v2', messages)` |
|
||||||
|
| clarifai/tiiuae.falcon.falcon-40b-instruct | `completion('clarifai/tiiuae.falcon.falcon-40b-instruct', messages)` |
|
||||||
|
| clarifai/togethercomputer.RedPajama.RedPajama-INCITE-7B-Chat | `completion('clarifai/togethercomputer.RedPajama.RedPajama-INCITE-7B-Chat', messages)` |
|
||||||
|
| clarifai/bigcode.code.StarCoder | `completion('clarifai/bigcode.code.StarCoder', messages)` |
|
||||||
|
| clarifai/mosaicml.mpt.mpt-7b-instruct | `completion('clarifai/mosaicml.mpt.mpt-7b-instruct', messages)` |
|
54
docs/my-website/docs/providers/deepseek.md
Normal file
54
docs/my-website/docs/providers/deepseek.md
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
# Deepseek
|
||||||
|
https://deepseek.com/
|
||||||
|
|
||||||
|
**We support ALL Deepseek models, just set `deepseek/` as a prefix when sending completion requests**
|
||||||
|
|
||||||
|
## API Key
|
||||||
|
```python
|
||||||
|
# env variable
|
||||||
|
os.environ['DEEPSEEK_API_KEY']
|
||||||
|
```
|
||||||
|
|
||||||
|
## Sample Usage
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ['DEEPSEEK_API_KEY'] = ""
|
||||||
|
response = completion(
|
||||||
|
model="deepseek/deepseek-chat",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "hello from litellm"}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Sample Usage - Streaming
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ['DEEPSEEK_API_KEY'] = ""
|
||||||
|
response = completion(
|
||||||
|
model="deepseek/deepseek-chat",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "hello from litellm"}
|
||||||
|
],
|
||||||
|
stream=True
|
||||||
|
)
|
||||||
|
|
||||||
|
for chunk in response:
|
||||||
|
print(chunk)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Supported Models - ALL Deepseek Models Supported!
|
||||||
|
We support ALL Deepseek models, just set `deepseek/` as a prefix when sending completion requests
|
||||||
|
|
||||||
|
| Model Name | Function Call |
|
||||||
|
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||||
|
| deepseek-chat | `completion(model="deepseek/deepseek-chat", messages)` |
|
||||||
|
| deepseek-coder | `completion(model="deepseek/deepseek-chat", messages)` |
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,11 @@ This is done by adding the "huggingface/" prefix to `model`, example `completion
|
||||||
<Tabs>
|
<Tabs>
|
||||||
<TabItem value="tgi" label="Text-generation-interface (TGI)">
|
<TabItem value="tgi" label="Text-generation-interface (TGI)">
|
||||||
|
|
||||||
|
By default, LiteLLM will assume a huggingface call follows the TGI format.
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import os
|
import os
|
||||||
from litellm import completion
|
from litellm import completion
|
||||||
|
@ -40,9 +45,58 @@ response = completion(
|
||||||
print(response)
|
print(response)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Add models to your config.yaml
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: wizard-coder
|
||||||
|
litellm_params:
|
||||||
|
model: huggingface/WizardLM/WizardCoder-Python-34B-V1.0
|
||||||
|
api_key: os.environ/HUGGINGFACE_API_KEY
|
||||||
|
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
2. Start the proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ litellm --config /path/to/config.yaml --debug
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Test it!
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"model": "wizard-coder",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "I like you!"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
</TabItem>
|
</TabItem>
|
||||||
<TabItem value="conv" label="Conversational-task (BlenderBot, etc.)">
|
<TabItem value="conv" label="Conversational-task (BlenderBot, etc.)">
|
||||||
|
|
||||||
|
Append `conversational` to the model name
|
||||||
|
|
||||||
|
e.g. `huggingface/conversational/<model-name>`
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import os
|
import os
|
||||||
from litellm import completion
|
from litellm import completion
|
||||||
|
@ -54,7 +108,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
|
||||||
|
|
||||||
# e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
|
# e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
|
||||||
response = completion(
|
response = completion(
|
||||||
model="huggingface/facebook/blenderbot-400M-distill",
|
model="huggingface/conversational/facebook/blenderbot-400M-distill",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
api_base="https://my-endpoint.huggingface.cloud"
|
api_base="https://my-endpoint.huggingface.cloud"
|
||||||
)
|
)
|
||||||
|
@ -62,7 +116,123 @@ response = completion(
|
||||||
print(response)
|
print(response)
|
||||||
```
|
```
|
||||||
</TabItem>
|
</TabItem>
|
||||||
<TabItem value="none" label="Non TGI/Conversational-task LLMs">
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Add models to your config.yaml
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: blenderbot
|
||||||
|
litellm_params:
|
||||||
|
model: huggingface/conversational/facebook/blenderbot-400M-distill
|
||||||
|
api_key: os.environ/HUGGINGFACE_API_KEY
|
||||||
|
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
2. Start the proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ litellm --config /path/to/config.yaml --debug
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Test it!
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"model": "blenderbot",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "I like you!"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="classification" label="Text Classification">
|
||||||
|
|
||||||
|
Append `text-classification` to the model name
|
||||||
|
|
||||||
|
e.g. `huggingface/text-classification/<model-name>`
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from litellm import completion
|
||||||
|
|
||||||
|
# [OPTIONAL] set env var
|
||||||
|
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
|
||||||
|
|
||||||
|
messages = [{ "content": "I like you, I love you!","role": "user"}]
|
||||||
|
|
||||||
|
# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
|
||||||
|
response = completion(
|
||||||
|
model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
|
||||||
|
messages=messages,
|
||||||
|
api_base="https://my-endpoint.endpoints.huggingface.cloud",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Add models to your config.yaml
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: bert-classifier
|
||||||
|
litellm_params:
|
||||||
|
model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
|
||||||
|
api_key: os.environ/HUGGINGFACE_API_KEY
|
||||||
|
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
2. Start the proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ litellm --config /path/to/config.yaml --debug
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Test it!
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"model": "bert-classifier",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "I like you!"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="none" label="Text Generation (NOT TGI)">
|
||||||
|
|
||||||
|
Append `text-generation` to the model name
|
||||||
|
|
||||||
|
e.g. `huggingface/text-generation/<model-name>`
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import os
|
import os
|
||||||
|
@ -75,7 +245,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
|
||||||
|
|
||||||
# e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints
|
# e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints
|
||||||
response = completion(
|
response = completion(
|
||||||
model="huggingface/roneneldan/TinyStories-3M",
|
model="huggingface/text-generation/roneneldan/TinyStories-3M",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
|
api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
|
||||||
)
|
)
|
||||||
|
|
|
@ -45,13 +45,13 @@ for chunk in response:
|
||||||
All models listed here https://docs.mistral.ai/platform/endpoints are supported. We actively maintain the list of models, pricing, token window, etc. [here](https://github.com/BerriAI/litellm/blob/c1b25538277206b9f00de5254d80d6a83bb19a29/model_prices_and_context_window.json).
|
All models listed here https://docs.mistral.ai/platform/endpoints are supported. We actively maintain the list of models, pricing, token window, etc. [here](https://github.com/BerriAI/litellm/blob/c1b25538277206b9f00de5254d80d6a83bb19a29/model_prices_and_context_window.json).
|
||||||
|
|
||||||
| Model Name | Function Call |
|
| Model Name | Function Call |
|
||||||
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
|----------------|--------------------------------------------------------------|
|
||||||
| mistral-tiny | `completion(model="mistral/mistral-tiny", messages)` |
|
| Mistral Small | `completion(model="mistral/mistral-small-latest", messages)` |
|
||||||
| mistral-small | `completion(model="mistral/mistral-small", messages)` |
|
| Mistral Medium | `completion(model="mistral/mistral-medium-latest", messages)`|
|
||||||
| mistral-medium | `completion(model="mistral/mistral-medium", messages)` |
|
| Mistral Large | `completion(model="mistral/mistral-large-latest", messages)` |
|
||||||
| mistral-large-latest | `completion(model="mistral/mistral-large-latest", messages)` |
|
| Mistral 7B | `completion(model="mistral/open-mistral-7b", messages)` |
|
||||||
| open-mixtral-8x22b | `completion(model="mistral/open-mixtral-8x22b", messages)` |
|
| Mixtral 8x7B | `completion(model="mistral/open-mixtral-8x7b", messages)` |
|
||||||
|
| Mixtral 8x22B | `completion(model="mistral/open-mixtral-8x22b", messages)` |
|
||||||
|
|
||||||
## Function Calling
|
## Function Calling
|
||||||
|
|
||||||
|
@ -116,6 +116,6 @@ All models listed here https://docs.mistral.ai/platform/endpoints are supported
|
||||||
|
|
||||||
| Model Name | Function Call |
|
| Model Name | Function Call |
|
||||||
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||||
| mistral-embed | `embedding(model="mistral/mistral-embed", input)` |
|
| Mistral Embeddings | `embedding(model="mistral/mistral-embed", input)` |
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ os.environ["OPENAI_API_KEY"] = "your-api-key"
|
||||||
|
|
||||||
# openai call
|
# openai call
|
||||||
response = completion(
|
response = completion(
|
||||||
model = "gpt-3.5-turbo",
|
model = "gpt-4o",
|
||||||
messages=[{ "content": "Hello, how are you?","role": "user"}]
|
messages=[{ "content": "Hello, how are you?","role": "user"}]
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
@ -163,6 +163,8 @@ os.environ["OPENAI_API_BASE"] = "openaiai-api-base" # OPTIONAL
|
||||||
|
|
||||||
| Model Name | Function Call |
|
| Model Name | Function Call |
|
||||||
|-----------------------|-----------------------------------------------------------------|
|
|-----------------------|-----------------------------------------------------------------|
|
||||||
|
| gpt-4o | `response = completion(model="gpt-4o", messages=messages)` |
|
||||||
|
| gpt-4o-2024-05-13 | `response = completion(model="gpt-4o-2024-05-13", messages=messages)` |
|
||||||
| gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` |
|
| gpt-4-turbo | `response = completion(model="gpt-4-turbo", messages=messages)` |
|
||||||
| gpt-4-turbo-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` |
|
| gpt-4-turbo-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` |
|
||||||
| gpt-4-0125-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` |
|
| gpt-4-0125-preview | `response = completion(model="gpt-4-0125-preview", messages=messages)` |
|
||||||
|
|
247
docs/my-website/docs/providers/predibase.md
Normal file
247
docs/my-website/docs/providers/predibase.md
Normal file
|
@ -0,0 +1,247 @@
|
||||||
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
|
# 🆕 Predibase
|
||||||
|
|
||||||
|
LiteLLM supports all models on Predibase
|
||||||
|
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
### API KEYS
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
os.environ["PREDIBASE_API_KEY"] = ""
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example Call
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import os
|
||||||
|
## set ENV variables
|
||||||
|
os.environ["PREDIBASE_API_KEY"] = "predibase key"
|
||||||
|
os.environ["PREDIBASE_TENANT_ID"] = "predibase tenant id"
|
||||||
|
|
||||||
|
# predibase llama-3 call
|
||||||
|
response = completion(
|
||||||
|
model="predibase/llama-3-8b-instruct",
|
||||||
|
messages = [{ "content": "Hello, how are you?","role": "user"}]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Add models to your config.yaml
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: llama-3
|
||||||
|
litellm_params:
|
||||||
|
model: predibase/llama-3-8b-instruct
|
||||||
|
api_key: os.environ/PREDIBASE_API_KEY
|
||||||
|
tenant_id: os.environ/PREDIBASE_TENANT_ID
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
2. Start the proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ litellm --config /path/to/config.yaml --debug
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Send Request to LiteLLM Proxy Server
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
|
||||||
|
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
client = openai.OpenAI(
|
||||||
|
api_key="sk-1234", # pass litellm proxy key, if you're using virtual keys
|
||||||
|
base_url="http://0.0.0.0:4000" # litellm-proxy-base url
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="llama-3",
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Be a good human!"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What do you know about earth?"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="curl" label="curl">
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"model": "llama-3",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Be a good human!"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What do you know about earth?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
## Advanced Usage - Prompt Formatting
|
||||||
|
|
||||||
|
LiteLLM has prompt template mappings for all `meta-llama` llama3 instruct models. [**See Code**](https://github.com/BerriAI/litellm/blob/4f46b4c3975cd0f72b8c5acb2cb429d23580c18a/litellm/llms/prompt_templates/factory.py#L1360)
|
||||||
|
|
||||||
|
To apply a custom prompt template:
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
import os
|
||||||
|
os.environ["PREDIBASE_API_KEY"] = ""
|
||||||
|
|
||||||
|
# Create your own custom prompt template
|
||||||
|
litellm.register_prompt_template(
|
||||||
|
model="togethercomputer/LLaMA-2-7B-32K",
|
||||||
|
initial_prompt_value="You are a good assistant" # [OPTIONAL]
|
||||||
|
roles={
|
||||||
|
"system": {
|
||||||
|
"pre_message": "[INST] <<SYS>>\n", # [OPTIONAL]
|
||||||
|
"post_message": "\n<</SYS>>\n [/INST]\n" # [OPTIONAL]
|
||||||
|
},
|
||||||
|
"user": {
|
||||||
|
"pre_message": "[INST] ", # [OPTIONAL]
|
||||||
|
"post_message": " [/INST]" # [OPTIONAL]
|
||||||
|
},
|
||||||
|
"assistant": {
|
||||||
|
"pre_message": "\n" # [OPTIONAL]
|
||||||
|
"post_message": "\n" # [OPTIONAL]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
final_prompt_value="Now answer as best you can:" # [OPTIONAL]
|
||||||
|
)
|
||||||
|
|
||||||
|
def predibase_custom_model():
|
||||||
|
model = "predibase/togethercomputer/LLaMA-2-7B-32K"
|
||||||
|
response = completion(model=model, messages=messages)
|
||||||
|
print(response['choices'][0]['message']['content'])
|
||||||
|
return response
|
||||||
|
|
||||||
|
predibase_custom_model()
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Model-specific parameters
|
||||||
|
model_list:
|
||||||
|
- model_name: mistral-7b # model alias
|
||||||
|
litellm_params: # actual params for litellm.completion()
|
||||||
|
model: "predibase/mistralai/Mistral-7B-Instruct-v0.1"
|
||||||
|
api_key: os.environ/PREDIBASE_API_KEY
|
||||||
|
initial_prompt_value: "\n"
|
||||||
|
roles: {"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}}
|
||||||
|
final_prompt_value: "\n"
|
||||||
|
bos_token: "<s>"
|
||||||
|
eos_token: "</s>"
|
||||||
|
max_tokens: 4096
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
## Passing additional params - max_tokens, temperature
|
||||||
|
See all litellm.completion supported params [here](https://docs.litellm.ai/docs/completion/input)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# !pip install litellm
|
||||||
|
from litellm import completion
|
||||||
|
import os
|
||||||
|
## set ENV variables
|
||||||
|
os.environ["PREDIBASE_API_KEY"] = "predibase key"
|
||||||
|
|
||||||
|
# predibae llama-3 call
|
||||||
|
response = completion(
|
||||||
|
model="predibase/llama3-8b-instruct",
|
||||||
|
messages = [{ "content": "Hello, how are you?","role": "user"}],
|
||||||
|
max_tokens=20,
|
||||||
|
temperature=0.5
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**proxy**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: llama-3
|
||||||
|
litellm_params:
|
||||||
|
model: predibase/llama-3-8b-instruct
|
||||||
|
api_key: os.environ/PREDIBASE_API_KEY
|
||||||
|
max_tokens: 20
|
||||||
|
temperature: 0.5
|
||||||
|
```
|
||||||
|
|
||||||
|
## Passings Predibase specific params - adapter_id, adapter_source,
|
||||||
|
Send params [not supported by `litellm.completion()`](https://docs.litellm.ai/docs/completion/input) but supported by Predibase by passing them to `litellm.completion`
|
||||||
|
|
||||||
|
Example `adapter_id`, `adapter_source` are Predibase specific param - [See List](https://github.com/BerriAI/litellm/blob/8a35354dd6dbf4c2fcefcd6e877b980fcbd68c58/litellm/llms/predibase.py#L54)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# !pip install litellm
|
||||||
|
from litellm import completion
|
||||||
|
import os
|
||||||
|
## set ENV variables
|
||||||
|
os.environ["PREDIBASE_API_KEY"] = "predibase key"
|
||||||
|
|
||||||
|
# predibase llama3 call
|
||||||
|
response = completion(
|
||||||
|
model="predibase/llama-3-8b-instruct",
|
||||||
|
messages = [{ "content": "Hello, how are you?","role": "user"}],
|
||||||
|
adapter_id="my_repo/3",
|
||||||
|
adapter_soruce="pbase",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**proxy**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: llama-3
|
||||||
|
litellm_params:
|
||||||
|
model: predibase/llama-3-8b-instruct
|
||||||
|
api_key: os.environ/PREDIBASE_API_KEY
|
||||||
|
adapter_id: my_repo/3
|
||||||
|
adapter_source: pbase
|
||||||
|
```
|
95
docs/my-website/docs/providers/triton-inference-server.md
Normal file
95
docs/my-website/docs/providers/triton-inference-server.md
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
|
# Triton Inference Server
|
||||||
|
|
||||||
|
LiteLLM supports Embedding Models on Triton Inference Servers
|
||||||
|
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
|
||||||
|
### Example Call
|
||||||
|
|
||||||
|
Use the `triton/` prefix to route to triton server
|
||||||
|
```python
|
||||||
|
from litellm import embedding
|
||||||
|
import os
|
||||||
|
|
||||||
|
response = await litellm.aembedding(
|
||||||
|
model="triton/<your-triton-model>",
|
||||||
|
api_base="https://your-triton-api-base/triton/embeddings", # /embeddings endpoint you want litellm to call on your server
|
||||||
|
input=["good morning from litellm"],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Add models to your config.yaml
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: my-triton-model
|
||||||
|
litellm_params:
|
||||||
|
model: triton/<your-triton-model>"
|
||||||
|
api_base: https://your-triton-api-base/triton/embeddings
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
2. Start the proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ litellm --config /path/to/config.yaml --detailed_debug
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Send Request to LiteLLM Proxy Server
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
|
||||||
|
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
# set base_url to your proxy server
|
||||||
|
# set api_key to send to proxy server
|
||||||
|
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")
|
||||||
|
|
||||||
|
response = client.embeddings.create(
|
||||||
|
input=["hello from litellm"],
|
||||||
|
model="my-triton-model"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="curl" label="curl">
|
||||||
|
|
||||||
|
`--header` is optional, only required if you're using litellm proxy with Virtual Keys
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:4000/embeddings' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--data ' {
|
||||||
|
"model": "my-triton-model",
|
||||||
|
"input": ["write a litellm poem"]
|
||||||
|
}'
|
||||||
|
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
|
@ -477,6 +477,36 @@ print(response)
|
||||||
| code-gecko@latest| `completion('code-gecko@latest', messages)` |
|
| code-gecko@latest| `completion('code-gecko@latest', messages)` |
|
||||||
|
|
||||||
|
|
||||||
|
## Embedding Models
|
||||||
|
|
||||||
|
#### Usage - Embedding
|
||||||
|
```python
|
||||||
|
import litellm
|
||||||
|
from litellm import embedding
|
||||||
|
litellm.vertex_project = "hardy-device-38811" # Your Project ID
|
||||||
|
litellm.vertex_location = "us-central1" # proj location
|
||||||
|
|
||||||
|
response = embedding(
|
||||||
|
model="vertex_ai/textembedding-gecko",
|
||||||
|
input=["good morning from litellm"],
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Supported Embedding Models
|
||||||
|
All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a0249f630a6792d49dffc2c5d9b7/model_prices_and_context_window.json#L835) are supported
|
||||||
|
|
||||||
|
| Model Name | Function Call |
|
||||||
|
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||||
|
| textembedding-gecko | `embedding(model="vertex_ai/textembedding-gecko", input)` |
|
||||||
|
| textembedding-gecko-multilingual | `embedding(model="vertex_ai/textembedding-gecko-multilingual", input)` |
|
||||||
|
| textembedding-gecko-multilingual@001 | `embedding(model="vertex_ai/textembedding-gecko-multilingual@001", input)` |
|
||||||
|
| textembedding-gecko@001 | `embedding(model="vertex_ai/textembedding-gecko@001", input)` |
|
||||||
|
| textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` |
|
||||||
|
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
|
||||||
|
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
|
||||||
|
|
||||||
|
|
||||||
## Extra
|
## Extra
|
||||||
|
|
||||||
### Using `GOOGLE_APPLICATION_CREDENTIALS`
|
### Using `GOOGLE_APPLICATION_CREDENTIALS`
|
||||||
|
@ -520,6 +550,12 @@ def load_vertex_ai_credentials():
|
||||||
|
|
||||||
### Using GCP Service Account
|
### Using GCP Service Account
|
||||||
|
|
||||||
|
:::info
|
||||||
|
|
||||||
|
Trying to deploy LiteLLM on Google Cloud Run? Tutorial [here](https://docs.litellm.ai/docs/proxy/deploy#deploy-on-google-cloud-run)
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
1. Figure out the Service Account bound to the Google Cloud Run service
|
1. Figure out the Service Account bound to the Google Cloud Run service
|
||||||
|
|
||||||
<Image img={require('../../img/gcp_acc_1.png')} />
|
<Image img={require('../../img/gcp_acc_1.png')} />
|
||||||
|
|
|
@ -1,14 +1,22 @@
|
||||||
# 🚨 Alerting
|
# 🚨 Alerting
|
||||||
|
|
||||||
Get alerts for:
|
Get alerts for:
|
||||||
|
|
||||||
- Hanging LLM api calls
|
- Hanging LLM api calls
|
||||||
- Failed LLM api calls
|
- Failed LLM api calls
|
||||||
- Slow LLM api calls
|
- Slow LLM api calls
|
||||||
- Budget Tracking per key/user:
|
- Budget Tracking per key/user:
|
||||||
- When a User/Key crosses their Budget
|
- When a User/Key crosses their Budget
|
||||||
- When a User/Key is 15% away from crossing their Budget
|
- When a User/Key is 15% away from crossing their Budget
|
||||||
|
- Spend Reports - Weekly & Monthly spend per Team, Tag
|
||||||
- Failed db read/writes
|
- Failed db read/writes
|
||||||
|
|
||||||
|
As a bonus, you can also get "daily reports" posted to your slack channel.
|
||||||
|
These reports contain key metrics like:
|
||||||
|
|
||||||
|
- Top 5 deployments with most failed requests
|
||||||
|
- Top 5 slowest deployments
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
Set up a slack alert channel to receive alerts from proxy.
|
Set up a slack alert channel to receive alerts from proxy.
|
||||||
|
@ -20,7 +28,8 @@ Get a slack webhook url from https://api.slack.com/messaging/webhooks
|
||||||
|
|
||||||
### Step 2: Update config.yaml
|
### Step 2: Update config.yaml
|
||||||
|
|
||||||
Let's save a bad key to our proxy
|
- Set `SLACK_WEBHOOK_URL` in your proxy env to enable Slack alerts.
|
||||||
|
- Just for testing purposes, let's save a bad key to our proxy.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
model_list:
|
model_list:
|
||||||
|
@ -33,13 +42,11 @@ general_settings:
|
||||||
alerting: ["slack"]
|
alerting: ["slack"]
|
||||||
alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+
|
alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+
|
||||||
|
|
||||||
|
environment_variables:
|
||||||
|
SLACK_WEBHOOK_URL: "https://hooks.slack.com/services/<>/<>/<>"
|
||||||
|
SLACK_DAILY_REPORT_FREQUENCY: "86400" # 24 hours; Optional: defaults to 12 hours
|
||||||
```
|
```
|
||||||
|
|
||||||
Set `SLACK_WEBHOOK_URL` in your proxy env
|
|
||||||
|
|
||||||
```shell
|
|
||||||
SLACK_WEBHOOK_URL: "https://hooks.slack.com/services/<>/<>/<>"
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 3: Start proxy
|
### Step 3: Start proxy
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,136 @@
|
||||||
# Cost Tracking - Azure
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
|
# 💸 Spend Tracking
|
||||||
|
|
||||||
|
Track spend for keys, users, and teams across 100+ LLMs.
|
||||||
|
|
||||||
|
## Getting Spend Reports - To Charge Other Teams, API Keys
|
||||||
|
|
||||||
|
Use the `/global/spend/report` endpoint to get daily spend per team, with a breakdown of spend per API Key, Model
|
||||||
|
|
||||||
|
### Example Request
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl -X GET 'http://localhost:4000/global/spend/report?start_date=2023-04-01&end_date=2024-06-30' \
|
||||||
|
-H 'Authorization: Bearer sk-1234'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example Response
|
||||||
|
<Tabs>
|
||||||
|
|
||||||
|
<TabItem value="response" label="Expected Response">
|
||||||
|
|
||||||
|
```shell
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"group_by_day": "2024-04-30T00:00:00+00:00",
|
||||||
|
"teams": [
|
||||||
|
{
|
||||||
|
"team_name": "Prod Team",
|
||||||
|
"total_spend": 0.0015265,
|
||||||
|
"metadata": [ # see the spend by unique(key + model)
|
||||||
|
{
|
||||||
|
"model": "gpt-4",
|
||||||
|
"spend": 0.00123,
|
||||||
|
"total_tokens": 28,
|
||||||
|
"api_key": "88dc28.." # the hashed api key
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": "gpt-4",
|
||||||
|
"spend": 0.00123,
|
||||||
|
"total_tokens": 28,
|
||||||
|
"api_key": "a73dc2.." # the hashed api key
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": "chatgpt-v-2",
|
||||||
|
"spend": 0.000214,
|
||||||
|
"total_tokens": 122,
|
||||||
|
"api_key": "898c28.." # the hashed api key
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"spend": 0.0000825,
|
||||||
|
"total_tokens": 85,
|
||||||
|
"api_key": "84dc28.." # the hashed api key
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="py-script" label="Script to Parse Response (Python)">
|
||||||
|
|
||||||
|
```python
|
||||||
|
import requests
|
||||||
|
url = 'http://localhost:4000/global/spend/report'
|
||||||
|
params = {
|
||||||
|
'start_date': '2023-04-01',
|
||||||
|
'end_date': '2024-06-30'
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
'Authorization': 'Bearer sk-1234'
|
||||||
|
}
|
||||||
|
|
||||||
|
# Make the GET request
|
||||||
|
response = requests.get(url, headers=headers, params=params)
|
||||||
|
spend_report = response.json()
|
||||||
|
|
||||||
|
for row in spend_report:
|
||||||
|
date = row["group_by_day"]
|
||||||
|
teams = row["teams"]
|
||||||
|
for team in teams:
|
||||||
|
team_name = team["team_name"]
|
||||||
|
total_spend = team["total_spend"]
|
||||||
|
metadata = team["metadata"]
|
||||||
|
|
||||||
|
print(f"Date: {date}")
|
||||||
|
print(f"Team: {team_name}")
|
||||||
|
print(f"Total Spend: {total_spend}")
|
||||||
|
print("Metadata: ", metadata)
|
||||||
|
print()
|
||||||
|
```
|
||||||
|
|
||||||
|
Output from script
|
||||||
|
```shell
|
||||||
|
# Date: 2024-05-11T00:00:00+00:00
|
||||||
|
# Team: local_test_team
|
||||||
|
# Total Spend: 0.003675099999999999
|
||||||
|
# Metadata: [{'model': 'gpt-3.5-turbo', 'spend': 0.003675099999999999, 'api_key': 'b94d5e0bc3a71a573917fe1335dc0c14728c7016337451af9714924ff3a729db', 'total_tokens': 3105}]
|
||||||
|
|
||||||
|
# Date: 2024-05-13T00:00:00+00:00
|
||||||
|
# Team: Unassigned Team
|
||||||
|
# Total Spend: 3.4e-05
|
||||||
|
# Metadata: [{'model': 'gpt-3.5-turbo', 'spend': 3.4e-05, 'api_key': '9569d13c9777dba68096dea49b0b03e0aaf4d2b65d4030eda9e8a2733c3cd6e0', 'total_tokens': 50}]
|
||||||
|
|
||||||
|
# Date: 2024-05-13T00:00:00+00:00
|
||||||
|
# Team: central
|
||||||
|
# Total Spend: 0.000684
|
||||||
|
# Metadata: [{'model': 'gpt-3.5-turbo', 'spend': 0.000684, 'api_key': '0323facdf3af551594017b9ef162434a9b9a8ca1bbd9ccbd9d6ce173b1015605', 'total_tokens': 498}]
|
||||||
|
|
||||||
|
# Date: 2024-05-13T00:00:00+00:00
|
||||||
|
# Team: local_test_team
|
||||||
|
# Total Spend: 0.0005715000000000001
|
||||||
|
# Metadata: [{'model': 'gpt-3.5-turbo', 'spend': 0.0005715000000000001, 'api_key': 'b94d5e0bc3a71a573917fe1335dc0c14728c7016337451af9714924ff3a729db', 'total_tokens': 423}]
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
|
## Spend Tracking for Azure
|
||||||
|
|
||||||
Set base model for cost tracking azure image-gen call
|
Set base model for cost tracking azure image-gen call
|
||||||
|
|
||||||
## Image Generation
|
### Image Generation
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
model_list:
|
model_list:
|
||||||
|
@ -17,7 +145,7 @@ model_list:
|
||||||
mode: image_generation
|
mode: image_generation
|
||||||
```
|
```
|
||||||
|
|
||||||
## Chat Completions / Embeddings
|
### Chat Completions / Embeddings
|
||||||
|
|
||||||
**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking
|
**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking
|
||||||
|
|
||||||
|
|
83
docs/my-website/docs/proxy/customer_routing.md
Normal file
83
docs/my-website/docs/proxy/customer_routing.md
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
# Region-based Routing
|
||||||
|
|
||||||
|
Route specific customers to eu-only models.
|
||||||
|
|
||||||
|
By specifying 'allowed_model_region' for a customer, LiteLLM will filter-out any models in a model group which is not in the allowed region (i.e. 'eu').
|
||||||
|
|
||||||
|
[**See Code**](https://github.com/BerriAI/litellm/blob/5eb12e30cc5faa73799ebc7e48fc86ebf449c879/litellm/router.py#L2938)
|
||||||
|
|
||||||
|
### 1. Create customer with region-specification
|
||||||
|
|
||||||
|
Use the litellm 'end-user' object for this.
|
||||||
|
|
||||||
|
End-users can be tracked / id'ed by passing the 'user' param to litellm in an openai chat completion/embedding call.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST --location 'http://0.0.0.0:4000/end_user/new' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"user_id" : "ishaan-jaff-45",
|
||||||
|
"allowed_model_region": "eu", # 👈 SPECIFY ALLOWED REGION='eu'
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Add eu models to model-group
|
||||||
|
|
||||||
|
Add eu models to a model group. For azure models, litellm can automatically infer the region (no need to set it).
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: gpt-3.5-turbo
|
||||||
|
litellm_params:
|
||||||
|
model: azure/gpt-35-turbo-eu # 👈 EU azure model
|
||||||
|
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
|
||||||
|
api_key: os.environ/AZURE_EUROPE_API_KEY
|
||||||
|
- model_name: gpt-3.5-turbo
|
||||||
|
litellm_params:
|
||||||
|
model: azure/chatgpt-v-2
|
||||||
|
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
||||||
|
api_version: "2023-05-15"
|
||||||
|
api_key: os.environ/AZURE_API_KEY
|
||||||
|
|
||||||
|
router_settings:
|
||||||
|
enable_pre_call_checks: true # 👈 IMPORTANT
|
||||||
|
```
|
||||||
|
|
||||||
|
Start the proxy
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
litellm --config /path/to/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Test it!
|
||||||
|
|
||||||
|
Make a simple chat completions call to the proxy. In the response headers, you should see the returned api base.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST --location 'http://localhost:4000/chat/completions' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--data '{
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "what is the meaning of the universe? 1234"
|
||||||
|
}],
|
||||||
|
"user": "ishaan-jaff-45" # 👈 USER ID
|
||||||
|
}
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected API Base in response headers
|
||||||
|
|
||||||
|
```
|
||||||
|
x-litellm-api-base: "https://my-endpoint-europe-berri-992.openai.azure.com/"
|
||||||
|
```
|
||||||
|
|
||||||
|
### FAQ
|
||||||
|
|
||||||
|
**What happens if there are no available models for that region?**
|
||||||
|
|
||||||
|
Since the router filters out models not in the specified region, it will return back as an error to the user, if no models in that region are available.
|
|
@ -3,7 +3,7 @@ import Tabs from '@theme/Tabs';
|
||||||
import TabItem from '@theme/TabItem';
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
|
|
||||||
# 🔎 Logging - Custom Callbacks, DataDog, Langfuse, s3 Bucket, Sentry, OpenTelemetry, Athina
|
# 🔎 Logging - Custom Callbacks, DataDog, Langfuse, s3 Bucket, Sentry, OpenTelemetry, Athina, Azure Content-Safety
|
||||||
|
|
||||||
Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry, LangFuse, DynamoDB, s3 Bucket
|
Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry, LangFuse, DynamoDB, s3 Bucket
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTeleme
|
||||||
- [Logging to Sentry](#logging-proxy-inputoutput---sentry)
|
- [Logging to Sentry](#logging-proxy-inputoutput---sentry)
|
||||||
- [Logging to Traceloop (OpenTelemetry)](#logging-proxy-inputoutput-traceloop-opentelemetry)
|
- [Logging to Traceloop (OpenTelemetry)](#logging-proxy-inputoutput-traceloop-opentelemetry)
|
||||||
- [Logging to Athina](#logging-proxy-inputoutput-athina)
|
- [Logging to Athina](#logging-proxy-inputoutput-athina)
|
||||||
|
- [(BETA) Moderation with Azure Content-Safety](#moderation-with-azure-content-safety)
|
||||||
|
|
||||||
## Custom Callback Class [Async]
|
## Custom Callback Class [Async]
|
||||||
Use this when you want to run custom callbacks in `python`
|
Use this when you want to run custom callbacks in `python`
|
||||||
|
@ -914,39 +915,72 @@ Test Request
|
||||||
litellm --test
|
litellm --test
|
||||||
```
|
```
|
||||||
|
|
||||||
## Logging Proxy Input/Output Traceloop (OpenTelemetry)
|
## Logging Proxy Input/Output in OpenTelemetry format using Traceloop's OpenLLMetry
|
||||||
|
|
||||||
Traceloop allows you to log LLM Input/Output in the OpenTelemetry format
|
[OpenLLMetry](https://github.com/traceloop/openllmetry) _(built and maintained by Traceloop)_ is a set of extensions
|
||||||
|
built on top of [OpenTelemetry](https://opentelemetry.io/) that gives you complete observability over your LLM
|
||||||
|
application. Because it uses OpenTelemetry under the
|
||||||
|
hood, [it can be connected to various observability solutions](https://www.traceloop.com/docs/openllmetry/integrations/introduction)
|
||||||
|
like:
|
||||||
|
|
||||||
We will use the `--config` to set `litellm.success_callback = ["traceloop"]` this will log all successfull LLM calls to traceloop
|
* [Traceloop](https://www.traceloop.com/docs/openllmetry/integrations/traceloop)
|
||||||
|
* [Axiom](https://www.traceloop.com/docs/openllmetry/integrations/axiom)
|
||||||
|
* [Azure Application Insights](https://www.traceloop.com/docs/openllmetry/integrations/azure)
|
||||||
|
* [Datadog](https://www.traceloop.com/docs/openllmetry/integrations/datadog)
|
||||||
|
* [Dynatrace](https://www.traceloop.com/docs/openllmetry/integrations/dynatrace)
|
||||||
|
* [Grafana Tempo](https://www.traceloop.com/docs/openllmetry/integrations/grafana)
|
||||||
|
* [Honeycomb](https://www.traceloop.com/docs/openllmetry/integrations/honeycomb)
|
||||||
|
* [HyperDX](https://www.traceloop.com/docs/openllmetry/integrations/hyperdx)
|
||||||
|
* [Instana](https://www.traceloop.com/docs/openllmetry/integrations/instana)
|
||||||
|
* [New Relic](https://www.traceloop.com/docs/openllmetry/integrations/newrelic)
|
||||||
|
* [OpenTelemetry Collector](https://www.traceloop.com/docs/openllmetry/integrations/otel-collector)
|
||||||
|
* [Service Now Cloud Observability](https://www.traceloop.com/docs/openllmetry/integrations/service-now)
|
||||||
|
* [Sentry](https://www.traceloop.com/docs/openllmetry/integrations/sentry)
|
||||||
|
* [SigNoz](https://www.traceloop.com/docs/openllmetry/integrations/signoz)
|
||||||
|
* [Splunk](https://www.traceloop.com/docs/openllmetry/integrations/splunk)
|
||||||
|
|
||||||
**Step 1** Install traceloop-sdk and set Traceloop API key
|
We will use the `--config` to set `litellm.success_callback = ["traceloop"]` to achieve this, steps are listed below.
|
||||||
|
|
||||||
|
**Step 1:** Install the SDK
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
pip install traceloop-sdk -U
|
pip install traceloop-sdk
|
||||||
```
|
```
|
||||||
|
|
||||||
Traceloop outputs standard OpenTelemetry data that can be connected to your observability stack. Send standard OpenTelemetry from LiteLLM Proxy to [Traceloop](https://www.traceloop.com/docs/openllmetry/integrations/traceloop), [Dynatrace](https://www.traceloop.com/docs/openllmetry/integrations/dynatrace), [Datadog](https://www.traceloop.com/docs/openllmetry/integrations/datadog)
|
**Step 2:** Configure Environment Variable for trace exporting
|
||||||
, [New Relic](https://www.traceloop.com/docs/openllmetry/integrations/newrelic), [Honeycomb](https://www.traceloop.com/docs/openllmetry/integrations/honeycomb), [Grafana Tempo](https://www.traceloop.com/docs/openllmetry/integrations/grafana), [Splunk](https://www.traceloop.com/docs/openllmetry/integrations/splunk), [OpenTelemetry Collector](https://www.traceloop.com/docs/openllmetry/integrations/otel-collector)
|
|
||||||
|
You will need to configure where to export your traces. Environment variables will control this, example: For Traceloop
|
||||||
|
you should use `TRACELOOP_API_KEY`, whereas for Datadog you use `TRACELOOP_BASE_URL`. For more
|
||||||
|
visit [the Integrations Catalog](https://www.traceloop.com/docs/openllmetry/integrations/introduction).
|
||||||
|
|
||||||
|
If you are using Datadog as the observability solutions then you can set `TRACELOOP_BASE_URL` as:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
TRACELOOP_BASE_URL=http://<datadog-agent-hostname>:4318
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
|
||||||
|
|
||||||
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
|
|
||||||
```yaml
|
```yaml
|
||||||
model_list:
|
model_list:
|
||||||
- model_name: gpt-3.5-turbo
|
- model_name: gpt-3.5-turbo
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: gpt-3.5-turbo
|
model: gpt-3.5-turbo
|
||||||
|
api_key: my-fake-key # replace api_key with actual key
|
||||||
litellm_settings:
|
litellm_settings:
|
||||||
success_callback: ["traceloop"]
|
success_callback: [ "traceloop" ]
|
||||||
```
|
```
|
||||||
|
|
||||||
**Step 3**: Start the proxy, make a test request
|
**Step 4**: Start the proxy, make a test request
|
||||||
|
|
||||||
Start proxy
|
Start proxy
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
litellm --config config.yaml --debug
|
litellm --config config.yaml --debug
|
||||||
```
|
```
|
||||||
|
|
||||||
Test Request
|
Test Request
|
||||||
|
|
||||||
```
|
```
|
||||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
--header 'Content-Type: application/json' \
|
--header 'Content-Type: application/json' \
|
||||||
|
@ -1004,3 +1038,86 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
]
|
]
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## (BETA) Moderation with Azure Content Safety
|
||||||
|
|
||||||
|
[Azure Content-Safety](https://azure.microsoft.com/en-us/products/ai-services/ai-content-safety) is a Microsoft Azure service that provides content moderation APIs to detect potential offensive, harmful, or risky content in text.
|
||||||
|
|
||||||
|
We will use the `--config` to set `litellm.success_callback = ["azure_content_safety"]` this will moderate all LLM calls using Azure Content Safety.
|
||||||
|
|
||||||
|
**Step 0** Deploy Azure Content Safety
|
||||||
|
|
||||||
|
Deploy an Azure Content-Safety instance from the Azure Portal and get the `endpoint` and `key`.
|
||||||
|
|
||||||
|
**Step 1** Set Athina API key
|
||||||
|
|
||||||
|
```shell
|
||||||
|
AZURE_CONTENT_SAFETY_KEY = "<your-azure-content-safety-key>"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: gpt-3.5-turbo
|
||||||
|
litellm_params:
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
litellm_settings:
|
||||||
|
callbacks: ["azure_content_safety"]
|
||||||
|
azure_content_safety_params:
|
||||||
|
endpoint: "<your-azure-content-safety-endpoint>"
|
||||||
|
key: "os.environ/AZURE_CONTENT_SAFETY_KEY"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3**: Start the proxy, make a test request
|
||||||
|
|
||||||
|
Start proxy
|
||||||
|
```shell
|
||||||
|
litellm --config config.yaml --debug
|
||||||
|
```
|
||||||
|
|
||||||
|
Test Request
|
||||||
|
```
|
||||||
|
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data ' {
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hi, how are you?"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
An HTTP 400 error will be returned if the content is detected with a value greater than the threshold set in the `config.yaml`.
|
||||||
|
The details of the response will describe :
|
||||||
|
- The `source` : input text or llm generated text
|
||||||
|
- The `category` : the category of the content that triggered the moderation
|
||||||
|
- The `severity` : the severity from 0 to 10
|
||||||
|
|
||||||
|
**Step 4**: Customizing Azure Content Safety Thresholds
|
||||||
|
|
||||||
|
You can customize the thresholds for each category by setting the `thresholds` in the `config.yaml`
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: gpt-3.5-turbo
|
||||||
|
litellm_params:
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
litellm_settings:
|
||||||
|
callbacks: ["azure_content_safety"]
|
||||||
|
azure_content_safety_params:
|
||||||
|
endpoint: "<your-azure-content-safety-endpoint>"
|
||||||
|
key: "os.environ/AZURE_CONTENT_SAFETY_KEY"
|
||||||
|
thresholds:
|
||||||
|
Hate: 6
|
||||||
|
SelfHarm: 8
|
||||||
|
Sexual: 6
|
||||||
|
Violence: 4
|
||||||
|
```
|
||||||
|
|
||||||
|
:::info
|
||||||
|
`thresholds` are not required by default, but you can tune the values to your needs.
|
||||||
|
Default values is `4` for all categories
|
||||||
|
:::
|
|
@ -3,34 +3,38 @@ import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
# ⚡ Best Practices for Production
|
# ⚡ Best Practices for Production
|
||||||
|
|
||||||
Expected Performance in Production
|
## 1. Use this config.yaml
|
||||||
|
Use this config.yaml in production (with your own LLMs)
|
||||||
|
|
||||||
1 LiteLLM Uvicorn Worker on Kubernetes
|
|
||||||
|
|
||||||
| Description | Value |
|
|
||||||
|--------------|-------|
|
|
||||||
| Avg latency | `50ms` |
|
|
||||||
| Median latency | `51ms` |
|
|
||||||
| `/chat/completions` Requests/second | `35` |
|
|
||||||
| `/chat/completions` Requests/minute | `2100` |
|
|
||||||
| `/chat/completions` Requests/hour | `126K` |
|
|
||||||
|
|
||||||
|
|
||||||
## 1. Switch off Debug Logging
|
|
||||||
|
|
||||||
Remove `set_verbose: True` from your config.yaml
|
|
||||||
```yaml
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: fake-openai-endpoint
|
||||||
|
litellm_params:
|
||||||
|
model: openai/fake
|
||||||
|
api_key: fake-key
|
||||||
|
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||||
|
|
||||||
|
general_settings:
|
||||||
|
master_key: sk-1234 # enter your own master key, ensure it starts with 'sk-'
|
||||||
|
alerting: ["slack"] # Setup slack alerting - get alerts on LLM exceptions, Budget Alerts, Slow LLM Responses
|
||||||
|
proxy_batch_write_at: 60 # Batch write spend updates every 60s
|
||||||
|
|
||||||
litellm_settings:
|
litellm_settings:
|
||||||
set_verbose: True
|
set_verbose: False # Switch off Debug Logging, ensure your logs do not have any debugging on
|
||||||
```
|
```
|
||||||
|
|
||||||
You should only see the following level of details in logs on the proxy server
|
Set slack webhook url in your env
|
||||||
```shell
|
```shell
|
||||||
# INFO: 192.168.2.205:11774 - "POST /chat/completions HTTP/1.1" 200 OK
|
export SLACK_WEBHOOK_URL="https://hooks.slack.com/services/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"
|
||||||
# INFO: 192.168.2.205:34717 - "POST /chat/completions HTTP/1.1" 200 OK
|
|
||||||
# INFO: 192.168.2.205:29734 - "POST /chat/completions HTTP/1.1" 200 OK
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
:::info
|
||||||
|
|
||||||
|
Need Help or want dedicated support ? Talk to a founder [here]: (https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
|
||||||
## 2. On Kubernetes - Use 1 Uvicorn worker [Suggested CMD]
|
## 2. On Kubernetes - Use 1 Uvicorn worker [Suggested CMD]
|
||||||
|
|
||||||
Use this Docker `CMD`. This will start the proxy with 1 Uvicorn Async Worker
|
Use this Docker `CMD`. This will start the proxy with 1 Uvicorn Async Worker
|
||||||
|
@ -40,21 +44,12 @@ Use this Docker `CMD`. This will start the proxy with 1 Uvicorn Async Worker
|
||||||
CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
|
CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
|
||||||
```
|
```
|
||||||
|
|
||||||
## 3. Batch write spend updates every 60s
|
|
||||||
|
|
||||||
The default proxy batch write is 10s. This is to make it easy to see spend when debugging locally.
|
## 3. Use Redis 'port','host', 'password'. NOT 'redis_url'
|
||||||
|
|
||||||
In production, we recommend using a longer interval period of 60s. This reduces the number of connections used to make DB writes.
|
If you decide to use Redis, DO NOT use 'redis_url'. We recommend usig redis port, host, and password params.
|
||||||
|
|
||||||
```yaml
|
`redis_url`is 80 RPS slower
|
||||||
general_settings:
|
|
||||||
master_key: sk-1234
|
|
||||||
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 4. use Redis 'port','host', 'password'. NOT 'redis_url'
|
|
||||||
|
|
||||||
When connecting to Redis use redis port, host, and password params. Not 'redis_url'. We've seen a 80 RPS difference between these 2 approaches when using the async redis client.
|
|
||||||
|
|
||||||
This is still something we're investigating. Keep track of it [here](https://github.com/BerriAI/litellm/issues/3188)
|
This is still something we're investigating. Keep track of it [here](https://github.com/BerriAI/litellm/issues/3188)
|
||||||
|
|
||||||
|
@ -69,103 +64,31 @@ router_settings:
|
||||||
redis_password: os.environ/REDIS_PASSWORD
|
redis_password: os.environ/REDIS_PASSWORD
|
||||||
```
|
```
|
||||||
|
|
||||||
## 5. Switch off resetting budgets
|
## Extras
|
||||||
|
### Expected Performance in Production
|
||||||
|
|
||||||
Add this to your config.yaml. (Only spend per Key, User and Team will be tracked - spend per API Call will not be written to the LiteLLM Database)
|
1 LiteLLM Uvicorn Worker on Kubernetes
|
||||||
```yaml
|
|
||||||
general_settings:
|
|
||||||
disable_reset_budget: true
|
|
||||||
```
|
|
||||||
|
|
||||||
## 6. Move spend logs to separate server (BETA)
|
| Description | Value |
|
||||||
|
|--------------|-------|
|
||||||
Writing each spend log to the db can slow down your proxy. In testing we saw a 70% improvement in median response time, by moving writing spend logs to a separate server.
|
| Avg latency | `50ms` |
|
||||||
|
| Median latency | `51ms` |
|
||||||
👉 [LiteLLM Spend Logs Server](https://github.com/BerriAI/litellm/tree/main/litellm-js/spend-logs)
|
| `/chat/completions` Requests/second | `35` |
|
||||||
|
| `/chat/completions` Requests/minute | `2100` |
|
||||||
|
| `/chat/completions` Requests/hour | `126K` |
|
||||||
|
|
||||||
|
|
||||||
**Spend Logs**
|
### Verifying Debugging logs are off
|
||||||
This is a log of the key, tokens, model, and latency for each call on the proxy.
|
|
||||||
|
|
||||||
[**Full Payload**](https://github.com/BerriAI/litellm/blob/8c9623a6bc4ad9da0a2dac64249a60ed8da719e8/litellm/proxy/utils.py#L1769)
|
You should only see the following level of details in logs on the proxy server
|
||||||
|
```shell
|
||||||
|
# INFO: 192.168.2.205:11774 - "POST /chat/completions HTTP/1.1" 200 OK
|
||||||
**1. Start the spend logs server**
|
# INFO: 192.168.2.205:34717 - "POST /chat/completions HTTP/1.1" 200 OK
|
||||||
|
# INFO: 192.168.2.205:29734 - "POST /chat/completions HTTP/1.1" 200 OK
|
||||||
```bash
|
|
||||||
docker run -p 3000:3000 \
|
|
||||||
-e DATABASE_URL="postgres://.." \
|
|
||||||
ghcr.io/berriai/litellm-spend_logs:main-latest
|
|
||||||
|
|
||||||
# RUNNING on http://0.0.0.0:3000
|
|
||||||
```
|
|
||||||
|
|
||||||
**2. Connect to proxy**
|
|
||||||
|
|
||||||
|
|
||||||
Example litellm_config.yaml
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
model_list:
|
|
||||||
- model_name: fake-openai-endpoint
|
|
||||||
litellm_params:
|
|
||||||
model: openai/my-fake-model
|
|
||||||
api_key: my-fake-key
|
|
||||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
|
||||||
|
|
||||||
general_settings:
|
|
||||||
master_key: sk-1234
|
|
||||||
proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds)
|
|
||||||
```
|
|
||||||
|
|
||||||
Add `SPEND_LOGS_URL` as an environment variable when starting the proxy
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker run \
|
|
||||||
-v $(pwd)/litellm_config.yaml:/app/config.yaml \
|
|
||||||
-e DATABASE_URL="postgresql://.." \
|
|
||||||
-e SPEND_LOGS_URL="http://host.docker.internal:3000" \ # 👈 KEY CHANGE
|
|
||||||
-p 4000:4000 \
|
|
||||||
ghcr.io/berriai/litellm:main-latest \
|
|
||||||
--config /app/config.yaml --detailed_debug
|
|
||||||
|
|
||||||
# Running on http://0.0.0.0:4000
|
|
||||||
```
|
|
||||||
|
|
||||||
**3. Test Proxy!**
|
|
||||||
|
|
||||||
|
|
||||||
```bash
|
|
||||||
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
|
|
||||||
--header 'Content-Type: application/json' \
|
|
||||||
--header 'Authorization: Bearer sk-1234' \
|
|
||||||
--data '{
|
|
||||||
"model": "fake-openai-endpoint",
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "Be helpful"},
|
|
||||||
{"role": "user", "content": "What do you know?"}
|
|
||||||
]
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
|
|
||||||
In your LiteLLM Spend Logs Server, you should see
|
|
||||||
|
|
||||||
**Expected Response**
|
|
||||||
|
|
||||||
```
|
|
||||||
Received and stored 1 logs. Total logs in memory: 1
|
|
||||||
...
|
|
||||||
Flushed 1 log to the DB.
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
### Machine Specification
|
### Machine Specifications to Deploy LiteLLM
|
||||||
|
|
||||||
A t2.micro should be sufficient to handle 1k logs / minute on this server.
|
|
||||||
|
|
||||||
This consumes at max 120MB, and <0.1 vCPU.
|
|
||||||
|
|
||||||
## Machine Specifications to Deploy LiteLLM
|
|
||||||
|
|
||||||
| Service | Spec | CPUs | Memory | Architecture | Version|
|
| Service | Spec | CPUs | Memory | Architecture | Version|
|
||||||
| --- | --- | --- | --- | --- | --- |
|
| --- | --- | --- | --- | --- | --- |
|
||||||
|
@ -173,7 +96,7 @@ This consumes at max 120MB, and <0.1 vCPU.
|
||||||
| Redis Cache | - | - | - | - | 7.0+ Redis Engine|
|
| Redis Cache | - | - | - | - | 7.0+ Redis Engine|
|
||||||
|
|
||||||
|
|
||||||
## Reference Kubernetes Deployment YAML
|
### Reference Kubernetes Deployment YAML
|
||||||
|
|
||||||
Reference Kubernetes `deployment.yaml` that was load tested by us
|
Reference Kubernetes `deployment.yaml` that was load tested by us
|
||||||
|
|
||||||
|
|
|
@ -151,7 +151,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
## Advanced - Context Window Fallbacks
|
## Advanced - Context Window Fallbacks (Pre-Call Checks + Fallbacks)
|
||||||
|
|
||||||
**Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**.
|
**Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**.
|
||||||
|
|
||||||
|
@ -287,6 +287,69 @@ print(response)
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
|
## Advanced - EU-Region Filtering (Pre-Call Checks)
|
||||||
|
|
||||||
|
**Before call is made** check if a call is within model context window with **`enable_pre_call_checks: true`**.
|
||||||
|
|
||||||
|
Set 'region_name' of deployment.
|
||||||
|
|
||||||
|
**Note:** LiteLLM can automatically infer region_name for Vertex AI, Bedrock, and IBM WatsonxAI based on your litellm params. For Azure, set `litellm.enable_preview = True`.
|
||||||
|
|
||||||
|
**1. Set Config**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
router_settings:
|
||||||
|
enable_pre_call_checks: true # 1. Enable pre-call checks
|
||||||
|
|
||||||
|
model_list:
|
||||||
|
- model_name: gpt-3.5-turbo
|
||||||
|
litellm_params:
|
||||||
|
model: azure/chatgpt-v-2
|
||||||
|
api_base: os.environ/AZURE_API_BASE
|
||||||
|
api_key: os.environ/AZURE_API_KEY
|
||||||
|
api_version: "2023-07-01-preview"
|
||||||
|
region_name: "eu" # 👈 SET EU-REGION
|
||||||
|
|
||||||
|
- model_name: gpt-3.5-turbo
|
||||||
|
litellm_params:
|
||||||
|
model: gpt-3.5-turbo-1106
|
||||||
|
api_key: os.environ/OPENAI_API_KEY
|
||||||
|
|
||||||
|
- model_name: gemini-pro
|
||||||
|
litellm_params:
|
||||||
|
model: vertex_ai/gemini-pro-1.5
|
||||||
|
vertex_project: adroit-crow-1234
|
||||||
|
vertex_location: us-east1 # 👈 AUTOMATICALLY INFERS 'region_name'
|
||||||
|
```
|
||||||
|
|
||||||
|
**2. Start proxy**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
litellm --config /path/to/config.yaml
|
||||||
|
|
||||||
|
# RUNNING on http://0.0.0.0:4000
|
||||||
|
```
|
||||||
|
|
||||||
|
**3. Test it!**
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
client = openai.OpenAI(
|
||||||
|
api_key="anything",
|
||||||
|
base_url="http://0.0.0.0:4000"
|
||||||
|
)
|
||||||
|
|
||||||
|
# request sent to model set on litellm proxy, `litellm --model`
|
||||||
|
response = client.chat.completions.with_raw_response.create(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages = [{"role": "user", "content": "Who was Alexander?"}]
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
print(f"response.headers.get('x-litellm-model-api-base')")
|
||||||
|
```
|
||||||
|
|
||||||
## Advanced - Custom Timeouts, Stream Timeouts - Per Model
|
## Advanced - Custom Timeouts, Stream Timeouts - Per Model
|
||||||
For each model you can set `timeout` & `stream_timeout` under `litellm_params`
|
For each model you can set `timeout` & `stream_timeout` under `litellm_params`
|
||||||
```yaml
|
```yaml
|
||||||
|
|
|
@ -17,6 +17,7 @@ This is a new feature, and subject to changes based on feedback.
|
||||||
### Step 1. Setup Proxy
|
### Step 1. Setup Proxy
|
||||||
|
|
||||||
- `JWT_PUBLIC_KEY_URL`: This is the public keys endpoint of your OpenID provider. Typically it's `{openid-provider-base-url}/.well-known/openid-configuration/jwks`. For Keycloak it's `{keycloak_base_url}/realms/{your-realm}/protocol/openid-connect/certs`.
|
- `JWT_PUBLIC_KEY_URL`: This is the public keys endpoint of your OpenID provider. Typically it's `{openid-provider-base-url}/.well-known/openid-configuration/jwks`. For Keycloak it's `{keycloak_base_url}/realms/{your-realm}/protocol/openid-connect/certs`.
|
||||||
|
- `JWT_AUDIENCE`: This is the audience used for decoding the JWT. If not set, the decode step will not verify the audience.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export JWT_PUBLIC_KEY_URL="" # "https://demo.duendesoftware.com/.well-known/openid-configuration/jwks"
|
export JWT_PUBLIC_KEY_URL="" # "https://demo.duendesoftware.com/.well-known/openid-configuration/jwks"
|
||||||
|
@ -109,7 +110,7 @@ general_settings:
|
||||||
admin_jwt_scope: "litellm-proxy-admin"
|
admin_jwt_scope: "litellm-proxy-admin"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Advanced - Spend Tracking (User / Team / Org)
|
## Advanced - Spend Tracking (End-Users / Internal Users / Team / Org)
|
||||||
|
|
||||||
Set the field in the jwt token, which corresponds to a litellm user / team / org.
|
Set the field in the jwt token, which corresponds to a litellm user / team / org.
|
||||||
|
|
||||||
|
@ -122,6 +123,7 @@ general_settings:
|
||||||
team_id_jwt_field: "client_id" # 👈 CAN BE ANY FIELD
|
team_id_jwt_field: "client_id" # 👈 CAN BE ANY FIELD
|
||||||
user_id_jwt_field: "sub" # 👈 CAN BE ANY FIELD
|
user_id_jwt_field: "sub" # 👈 CAN BE ANY FIELD
|
||||||
org_id_jwt_field: "org_id" # 👈 CAN BE ANY FIELD
|
org_id_jwt_field: "org_id" # 👈 CAN BE ANY FIELD
|
||||||
|
end_user_id_jwt_field: "customer_id" # 👈 CAN BE ANY FIELD
|
||||||
```
|
```
|
||||||
|
|
||||||
Expected JWT:
|
Expected JWT:
|
||||||
|
@ -130,7 +132,7 @@ Expected JWT:
|
||||||
{
|
{
|
||||||
"client_id": "my-unique-team",
|
"client_id": "my-unique-team",
|
||||||
"sub": "my-unique-user",
|
"sub": "my-unique-user",
|
||||||
"org_id": "my-unique-org"
|
"org_id": "my-unique-org",
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -365,6 +365,188 @@ curl --location 'http://0.0.0.0:4000/moderations' \
|
||||||
|
|
||||||
## Advanced
|
## Advanced
|
||||||
|
|
||||||
|
### (BETA) Batch Completions - pass multiple models
|
||||||
|
|
||||||
|
Use this when you want to send 1 request to N Models
|
||||||
|
|
||||||
|
#### Expected Request Format
|
||||||
|
|
||||||
|
Pass model as a string of comma separated value of models. Example `"model"="llama3,gpt-3.5-turbo"`
|
||||||
|
|
||||||
|
This same request will be sent to the following model groups on the [litellm proxy config.yaml](https://docs.litellm.ai/docs/proxy/configs)
|
||||||
|
- `model_name="llama3"`
|
||||||
|
- `model_name="gpt-3.5-turbo"`
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
|
||||||
|
<TabItem value="openai-py" label="OpenAI Python SDK">
|
||||||
|
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
|
||||||
|
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="gpt-3.5-turbo,llama3",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "this is a test request, write a short poem"}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#### Expected Response Format
|
||||||
|
|
||||||
|
Get a list of responses when `model` is passed as a list
|
||||||
|
|
||||||
|
```python
|
||||||
|
[
|
||||||
|
ChatCompletion(
|
||||||
|
id='chatcmpl-9NoYhS2G0fswot0b6QpoQgmRQMaIf',
|
||||||
|
choices=[
|
||||||
|
Choice(
|
||||||
|
finish_reason='stop',
|
||||||
|
index=0,
|
||||||
|
logprobs=None,
|
||||||
|
message=ChatCompletionMessage(
|
||||||
|
content='In the depths of my soul, a spark ignites\nA light that shines so pure and bright\nIt dances and leaps, refusing to die\nA flame of hope that reaches the sky\n\nIt warms my heart and fills me with bliss\nA reminder that in darkness, there is light to kiss\nSo I hold onto this fire, this guiding light\nAnd let it lead me through the darkest night.',
|
||||||
|
role='assistant',
|
||||||
|
function_call=None,
|
||||||
|
tool_calls=None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
],
|
||||||
|
created=1715462919,
|
||||||
|
model='gpt-3.5-turbo-0125',
|
||||||
|
object='chat.completion',
|
||||||
|
system_fingerprint=None,
|
||||||
|
usage=CompletionUsage(
|
||||||
|
completion_tokens=83,
|
||||||
|
prompt_tokens=17,
|
||||||
|
total_tokens=100
|
||||||
|
)
|
||||||
|
),
|
||||||
|
ChatCompletion(
|
||||||
|
id='chatcmpl-4ac3e982-da4e-486d-bddb-ed1d5cb9c03c',
|
||||||
|
choices=[
|
||||||
|
Choice(
|
||||||
|
finish_reason='stop',
|
||||||
|
index=0,
|
||||||
|
logprobs=None,
|
||||||
|
message=ChatCompletionMessage(
|
||||||
|
content="A test request, and I'm delighted!\nHere's a short poem, just for you:\n\nMoonbeams dance upon the sea,\nA path of light, for you to see.\nThe stars up high, a twinkling show,\nA night of wonder, for all to know.\n\nThe world is quiet, save the night,\nA peaceful hush, a gentle light.\nThe world is full, of beauty rare,\nA treasure trove, beyond compare.\n\nI hope you enjoyed this little test,\nA poem born, of whimsy and jest.\nLet me know, if there's anything else!",
|
||||||
|
role='assistant',
|
||||||
|
function_call=None,
|
||||||
|
tool_calls=None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
],
|
||||||
|
created=1715462919,
|
||||||
|
model='groq/llama3-8b-8192',
|
||||||
|
object='chat.completion',
|
||||||
|
system_fingerprint='fp_a2c8d063cb',
|
||||||
|
usage=CompletionUsage(
|
||||||
|
completion_tokens=120,
|
||||||
|
prompt_tokens=20,
|
||||||
|
total_tokens=140
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="curl" label="Curl">
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://localhost:4000/chat/completions' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"model": "llama3,gpt-3.5-turbo",
|
||||||
|
"max_tokens": 10,
|
||||||
|
"user": "litellm2",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "is litellm getting better"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#### Expected Response Format
|
||||||
|
|
||||||
|
Get a list of responses when `model` is passed as a list
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": "chatcmpl-3dbd5dd8-7c82-4ca3-bf1f-7c26f497cf2b",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"content": "The Elder Scrolls IV: Oblivion!\n\nReleased",
|
||||||
|
"role": "assistant"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1715459876,
|
||||||
|
"model": "groq/llama3-8b-8192",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": "fp_179b0f92c9",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 12,
|
||||||
|
"total_tokens": 22
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "chatcmpl-9NnldUfFLmVquFHSX4yAtjCw8PGei",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"content": "TES4 could refer to The Elder Scrolls IV:",
|
||||||
|
"role": "assistant"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1715459877,
|
||||||
|
"model": "gpt-3.5-turbo-0125",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": null,
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 9,
|
||||||
|
"total_tokens": 19
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Pass User LLM API Keys, Fallbacks
|
### Pass User LLM API Keys, Fallbacks
|
||||||
Allow your end-users to pass their model list, api base, OpenAI API key (any LiteLLM supported provider) to make requests
|
Allow your end-users to pass their model list, api base, OpenAI API key (any LiteLLM supported provider) to make requests
|
||||||
|
|
||||||
|
|
|
@ -12,8 +12,8 @@ Requirements:
|
||||||
|
|
||||||
You can set budgets at 3 levels:
|
You can set budgets at 3 levels:
|
||||||
- For the proxy
|
- For the proxy
|
||||||
- For a user
|
- For an internal user
|
||||||
- For a 'user' passed to `/chat/completions`, `/embeddings` etc
|
- For an end-user
|
||||||
- For a key
|
- For a key
|
||||||
- For a key (model specific budgets)
|
- For a key (model specific budgets)
|
||||||
|
|
||||||
|
@ -58,7 +58,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
</TabItem>
|
</TabItem>
|
||||||
<TabItem value="per-user" label="For User">
|
<TabItem value="per-user" label="For Internal User">
|
||||||
|
|
||||||
Apply a budget across multiple keys.
|
Apply a budget across multiple keys.
|
||||||
|
|
||||||
|
@ -165,12 +165,12 @@ curl --location 'http://localhost:4000/team/new' \
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
</TabItem>
|
</TabItem>
|
||||||
<TabItem value="per-user-chat" label="For 'user' passed to /chat/completions">
|
<TabItem value="per-user-chat" label="For End User">
|
||||||
|
|
||||||
Use this to budget `user` passed to `/chat/completions`, **without needing to create a key for every user**
|
Use this to budget `user` passed to `/chat/completions`, **without needing to create a key for every user**
|
||||||
|
|
||||||
**Step 1. Modify config.yaml**
|
**Step 1. Modify config.yaml**
|
||||||
Define `litellm.max_user_budget`
|
Define `litellm.max_end_user_budget`
|
||||||
```yaml
|
```yaml
|
||||||
general_settings:
|
general_settings:
|
||||||
master_key: sk-1234
|
master_key: sk-1234
|
||||||
|
@ -328,7 +328,7 @@ You can set:
|
||||||
- max parallel requests
|
- max parallel requests
|
||||||
|
|
||||||
<Tabs>
|
<Tabs>
|
||||||
<TabItem value="per-user" label="Per User">
|
<TabItem value="per-user" label="Per Internal User">
|
||||||
|
|
||||||
Use `/user/new`, to persist rate limits across multiple keys.
|
Use `/user/new`, to persist rate limits across multiple keys.
|
||||||
|
|
||||||
|
@ -408,7 +408,7 @@ curl --location 'http://localhost:4000/user/new' \
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Create new keys for existing user
|
## Create new keys for existing internal user
|
||||||
|
|
||||||
Just include user_id in the `/key/generate` request.
|
Just include user_id in the `/key/generate` request.
|
||||||
|
|
||||||
|
|
|
@ -96,7 +96,7 @@ print(response)
|
||||||
- `router.aimage_generation()` - async image generation calls
|
- `router.aimage_generation()` - async image generation calls
|
||||||
|
|
||||||
## Advanced - Routing Strategies
|
## Advanced - Routing Strategies
|
||||||
#### Routing Strategies - Weighted Pick, Rate Limit Aware, Least Busy, Latency Based
|
#### Routing Strategies - Weighted Pick, Rate Limit Aware, Least Busy, Latency Based, Cost Based
|
||||||
|
|
||||||
Router provides 4 strategies for routing your calls across multiple deployments:
|
Router provides 4 strategies for routing your calls across multiple deployments:
|
||||||
|
|
||||||
|
@ -467,6 +467,101 @@ async def router_acompletion():
|
||||||
asyncio.run(router_acompletion())
|
asyncio.run(router_acompletion())
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="lowest-cost" label="Lowest Cost Routing (Async)">
|
||||||
|
|
||||||
|
Picks a deployment based on the lowest cost
|
||||||
|
|
||||||
|
How this works:
|
||||||
|
- Get all healthy deployments
|
||||||
|
- Select all deployments that are under their provided `rpm/tpm` limits
|
||||||
|
- For each deployment check if `litellm_param["model"]` exists in [`litellm_model_cost_map`](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
|
||||||
|
- if deployment does not exist in `litellm_model_cost_map` -> use deployment_cost= `$1`
|
||||||
|
- Select deployment with lowest cost
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import Router
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
model_list = [
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {"model": "gpt-4"},
|
||||||
|
"model_info": {"id": "openai-gpt-4"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {"model": "groq/llama3-8b-8192"},
|
||||||
|
"model_info": {"id": "groq-llama"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# init router
|
||||||
|
router = Router(model_list=model_list, routing_strategy="cost-based-routing")
|
||||||
|
async def router_acompletion():
|
||||||
|
response = await router.acompletion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[{"role": "user", "content": "Hey, how's it going?"}]
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
print(response._hidden_params["model_id"]) # expect groq-llama, since groq/llama has lowest cost
|
||||||
|
return response
|
||||||
|
|
||||||
|
asyncio.run(router_acompletion())
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
#### Using Custom Input/Output pricing
|
||||||
|
|
||||||
|
Set `litellm_params["input_cost_per_token"]` and `litellm_params["output_cost_per_token"]` for using custom pricing when routing
|
||||||
|
|
||||||
|
```python
|
||||||
|
model_list = [
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"input_cost_per_token": 0.00003,
|
||||||
|
"output_cost_per_token": 0.00003,
|
||||||
|
},
|
||||||
|
"model_info": {"id": "chatgpt-v-experimental"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/chatgpt-v-1",
|
||||||
|
"input_cost_per_token": 0.000000001,
|
||||||
|
"output_cost_per_token": 0.00000001,
|
||||||
|
},
|
||||||
|
"model_info": {"id": "chatgpt-v-1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/chatgpt-v-5",
|
||||||
|
"input_cost_per_token": 10,
|
||||||
|
"output_cost_per_token": 12,
|
||||||
|
},
|
||||||
|
"model_info": {"id": "chatgpt-v-5"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
# init router
|
||||||
|
router = Router(model_list=model_list, routing_strategy="cost-based-routing")
|
||||||
|
async def router_acompletion():
|
||||||
|
response = await router.acompletion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[{"role": "user", "content": "Hey, how's it going?"}]
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
print(response._hidden_params["model_id"]) # expect chatgpt-v-1, since chatgpt-v-1 has lowest cost
|
||||||
|
return response
|
||||||
|
|
||||||
|
asyncio.run(router_acompletion())
|
||||||
|
```
|
||||||
|
|
||||||
</TabItem>
|
</TabItem>
|
||||||
|
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
@ -558,7 +653,9 @@ from litellm import Router
|
||||||
model_list = [{...}]
|
model_list = [{...}]
|
||||||
|
|
||||||
router = Router(model_list=model_list,
|
router = Router(model_list=model_list,
|
||||||
allowed_fails=1) # cooldown model if it fails > 1 call in a minute.
|
allowed_fails=1, # cooldown model if it fails > 1 call in a minute.
|
||||||
|
cooldown_time=100 # cooldown the deployment for 100 seconds if it num_fails > allowed_fails
|
||||||
|
)
|
||||||
|
|
||||||
user_message = "Hello, whats the weather in San Francisco??"
|
user_message = "Hello, whats the weather in San Francisco??"
|
||||||
messages = [{"content": user_message, "role": "user"}]
|
messages = [{"content": user_message, "role": "user"}]
|
||||||
|
@ -616,6 +713,57 @@ response = router.completion(model="gpt-3.5-turbo", messages=messages)
|
||||||
print(f"response: {response}")
|
print(f"response: {response}")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Retries based on Error Type
|
||||||
|
|
||||||
|
Use `RetryPolicy` if you want to set a `num_retries` based on the Exception receieved
|
||||||
|
|
||||||
|
Example:
|
||||||
|
- 4 retries for `ContentPolicyViolationError`
|
||||||
|
- 0 retries for `RateLimitErrors`
|
||||||
|
|
||||||
|
Example Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm.router import RetryPolicy
|
||||||
|
retry_policy = RetryPolicy(
|
||||||
|
ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors
|
||||||
|
AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
|
||||||
|
BadRequestErrorRetries=1,
|
||||||
|
TimeoutErrorRetries=2,
|
||||||
|
RateLimitErrorRetries=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = litellm.Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo", # openai model name
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": os.getenv("AZURE_API_KEY"),
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "bad-model", # openai model name
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": "bad-key",
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
retry_policy=retry_policy,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await router.acompletion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
### Fallbacks
|
### Fallbacks
|
||||||
|
|
||||||
If a call fails after num_retries, fall back to another model group.
|
If a call fails after num_retries, fall back to another model group.
|
||||||
|
@ -624,6 +772,8 @@ If the error is a context window exceeded error, fall back to a larger model gro
|
||||||
|
|
||||||
Fallbacks are done in-order - ["gpt-3.5-turbo, "gpt-4", "gpt-4-32k"], will do 'gpt-3.5-turbo' first, then 'gpt-4', etc.
|
Fallbacks are done in-order - ["gpt-3.5-turbo, "gpt-4", "gpt-4-32k"], will do 'gpt-3.5-turbo' first, then 'gpt-4', etc.
|
||||||
|
|
||||||
|
You can also set 'default_fallbacks', in case a specific model group is misconfigured / bad.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from litellm import Router
|
from litellm import Router
|
||||||
|
|
||||||
|
@ -684,6 +834,7 @@ model_list = [
|
||||||
|
|
||||||
router = Router(model_list=model_list,
|
router = Router(model_list=model_list,
|
||||||
fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
|
fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
|
||||||
|
default_fallbacks=["gpt-3.5-turbo-16k"],
|
||||||
context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
|
context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
|
||||||
set_verbose=True)
|
set_verbose=True)
|
||||||
|
|
||||||
|
@ -733,13 +884,11 @@ router = Router(model_list: Optional[list] = None,
|
||||||
cache_responses=True)
|
cache_responses=True)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Pre-Call Checks (Context Window)
|
## Pre-Call Checks (Context Window, EU-Regions)
|
||||||
|
|
||||||
Enable pre-call checks to filter out:
|
Enable pre-call checks to filter out:
|
||||||
1. deployments with context window limit < messages for a call.
|
1. deployments with context window limit < messages for a call.
|
||||||
2. deployments that have exceeded rate limits when making concurrent calls. (eg. `asyncio.gather(*[
|
2. deployments outside of eu-region
|
||||||
router.acompletion(model="gpt-3.5-turbo", messages=m) for m in list_of_messages
|
|
||||||
])`)
|
|
||||||
|
|
||||||
<Tabs>
|
<Tabs>
|
||||||
<TabItem value="sdk" label="SDK">
|
<TabItem value="sdk" label="SDK">
|
||||||
|
@ -754,10 +903,14 @@ router = Router(model_list=model_list, enable_pre_call_checks=True) # 👈 Set t
|
||||||
|
|
||||||
**2. Set Model List**
|
**2. Set Model List**
|
||||||
|
|
||||||
For azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with `azure/`.
|
For context window checks on azure deployments, set the base model. Pick the base model from [this list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json), all the azure models start with `azure/`.
|
||||||
|
|
||||||
<Tabs>
|
For 'eu-region' filtering, Set 'region_name' of deployment.
|
||||||
<TabItem value="same-group" label="Same Group">
|
|
||||||
|
**Note:** We automatically infer region_name for Vertex AI, Bedrock, and IBM WatsonxAI based on your litellm params. For Azure, set `litellm.enable_preview = True`.
|
||||||
|
|
||||||
|
|
||||||
|
[**See Code**](https://github.com/BerriAI/litellm/blob/d33e49411d6503cb634f9652873160cd534dec96/litellm/router.py#L2958)
|
||||||
|
|
||||||
```python
|
```python
|
||||||
model_list = [
|
model_list = [
|
||||||
|
@ -768,10 +921,9 @@ model_list = [
|
||||||
"api_key": os.getenv("AZURE_API_KEY"),
|
"api_key": os.getenv("AZURE_API_KEY"),
|
||||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
"api_base": os.getenv("AZURE_API_BASE"),
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
},
|
"region_name": "eu" # 👈 SET 'EU' REGION NAME
|
||||||
"model_info": {
|
|
||||||
"base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
|
"base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model_name": "gpt-3.5-turbo", # model group name
|
"model_name": "gpt-3.5-turbo", # model group name
|
||||||
|
@ -780,54 +932,26 @@ model_list = [
|
||||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"model_name": "gemini-pro",
|
||||||
|
"litellm_params: {
|
||||||
|
"model": "vertex_ai/gemini-pro-1.5",
|
||||||
|
"vertex_project": "adroit-crow-1234",
|
||||||
|
"vertex_location": "us-east1" # 👈 AUTOMATICALLY INFERS 'region_name'
|
||||||
|
}
|
||||||
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
router = Router(model_list=model_list, enable_pre_call_checks=True)
|
router = Router(model_list=model_list, enable_pre_call_checks=True)
|
||||||
```
|
```
|
||||||
|
|
||||||
</TabItem>
|
|
||||||
|
|
||||||
<TabItem value="different-group" label="Context Window Fallbacks (Different Groups)">
|
|
||||||
|
|
||||||
```python
|
|
||||||
model_list = [
|
|
||||||
{
|
|
||||||
"model_name": "gpt-3.5-turbo-small", # model group name
|
|
||||||
"litellm_params": { # params for litellm completion/embedding call
|
|
||||||
"model": "azure/chatgpt-v-2",
|
|
||||||
"api_key": os.getenv("AZURE_API_KEY"),
|
|
||||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
|
||||||
"api_base": os.getenv("AZURE_API_BASE"),
|
|
||||||
},
|
|
||||||
"model_info": {
|
|
||||||
"base_model": "azure/gpt-35-turbo", # 👈 (Azure-only) SET BASE MODEL
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"model_name": "gpt-3.5-turbo-large", # model group name
|
|
||||||
"litellm_params": { # params for litellm completion/embedding call
|
|
||||||
"model": "gpt-3.5-turbo-1106",
|
|
||||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"model_name": "claude-opus",
|
|
||||||
"litellm_params": { call
|
|
||||||
"model": "claude-3-opus-20240229",
|
|
||||||
"api_key": os.getenv("ANTHROPIC_API_KEY"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
router = Router(model_list=model_list, enable_pre_call_checks=True, context_window_fallbacks=[{"gpt-3.5-turbo-small": ["gpt-3.5-turbo-large", "claude-opus"]}])
|
|
||||||
```
|
|
||||||
|
|
||||||
</TabItem>
|
|
||||||
|
|
||||||
</Tabs>
|
|
||||||
|
|
||||||
**3. Test it!**
|
**3. Test it!**
|
||||||
|
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="context-window-check" label="Context Window Check">
|
||||||
|
|
||||||
```python
|
```python
|
||||||
"""
|
"""
|
||||||
- Give a gpt-3.5-turbo model group with different context windows (4k vs. 16k)
|
- Give a gpt-3.5-turbo model group with different context windows (4k vs. 16k)
|
||||||
|
@ -837,7 +961,6 @@ router = Router(model_list=model_list, enable_pre_call_checks=True, context_wind
|
||||||
from litellm import Router
|
from litellm import Router
|
||||||
import os
|
import os
|
||||||
|
|
||||||
try:
|
|
||||||
model_list = [
|
model_list = [
|
||||||
{
|
{
|
||||||
"model_name": "gpt-3.5-turbo", # model group name
|
"model_name": "gpt-3.5-turbo", # model group name
|
||||||
|
@ -846,6 +969,7 @@ model_list = [
|
||||||
"api_key": os.getenv("AZURE_API_KEY"),
|
"api_key": os.getenv("AZURE_API_KEY"),
|
||||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
"api_base": os.getenv("AZURE_API_BASE"),
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
"base_model": "azure/gpt-35-turbo",
|
||||||
},
|
},
|
||||||
"model_info": {
|
"model_info": {
|
||||||
"base_model": "azure/gpt-35-turbo",
|
"base_model": "azure/gpt-35-turbo",
|
||||||
|
@ -875,6 +999,59 @@ response = router.completion(
|
||||||
print(f"response: {response}")
|
print(f"response: {response}")
|
||||||
```
|
```
|
||||||
</TabItem>
|
</TabItem>
|
||||||
|
<TabItem value="eu-region-check" label="EU Region Check">
|
||||||
|
|
||||||
|
```python
|
||||||
|
"""
|
||||||
|
- Give 2 gpt-3.5-turbo deployments, in eu + non-eu regions
|
||||||
|
- Make a call
|
||||||
|
- Assert it picks the eu-region model
|
||||||
|
"""
|
||||||
|
|
||||||
|
from litellm import Router
|
||||||
|
import os
|
||||||
|
|
||||||
|
model_list = [
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo", # model group name
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": os.getenv("AZURE_API_KEY"),
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
"region_name": "eu"
|
||||||
|
},
|
||||||
|
"model_info": {
|
||||||
|
"id": "1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo", # model group name
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "gpt-3.5-turbo-1106",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
"model_info": {
|
||||||
|
"id": "2"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
router = Router(model_list=model_list, enable_pre_call_checks=True)
|
||||||
|
|
||||||
|
response = router.completion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[{"role": "user", "content": "Who was Alexander?"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"response: {response}")
|
||||||
|
|
||||||
|
print(f"response id: {response._hidden_params['model_id']}")
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
</TabItem>
|
||||||
<TabItem value="proxy" label="Proxy">
|
<TabItem value="proxy" label="Proxy">
|
||||||
|
|
||||||
:::info
|
:::info
|
||||||
|
@ -940,6 +1117,46 @@ async def test_acompletion_caching_on_router_caching_groups():
|
||||||
asyncio.run(test_acompletion_caching_on_router_caching_groups())
|
asyncio.run(test_acompletion_caching_on_router_caching_groups())
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Alerting 🚨
|
||||||
|
|
||||||
|
Send alerts to slack / your webhook url for the following events
|
||||||
|
- LLM API Exceptions
|
||||||
|
- Slow LLM Responses
|
||||||
|
|
||||||
|
Get a slack webhook url from https://api.slack.com/messaging/webhooks
|
||||||
|
|
||||||
|
#### Usage
|
||||||
|
Initialize an `AlertingConfig` and pass it to `litellm.Router`. The following code will trigger an alert because `api_key=bad-key` which is invalid
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm.router import AlertingConfig
|
||||||
|
import litellm
|
||||||
|
import os
|
||||||
|
|
||||||
|
router = litellm.Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"api_key": "bad_key",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
alerting_config= AlertingConfig(
|
||||||
|
alerting_threshold=10, # threshold for slow / hanging llm responses (in seconds). Defaults to 300 seconds
|
||||||
|
webhook_url= os.getenv("SLACK_WEBHOOK_URL") # webhook you want to send alerts to
|
||||||
|
),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await router.acompletion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
||||||
## Track cost for Azure Deployments
|
## Track cost for Azure Deployments
|
||||||
|
|
||||||
**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking
|
**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking
|
||||||
|
@ -1097,10 +1314,11 @@ def __init__(
|
||||||
num_retries: int = 0,
|
num_retries: int = 0,
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
default_litellm_params={}, # default params for Router.chat.completion.create
|
default_litellm_params={}, # default params for Router.chat.completion.create
|
||||||
fallbacks: List = [],
|
fallbacks: Optional[List] = None,
|
||||||
|
default_fallbacks: Optional[List] = None
|
||||||
allowed_fails: Optional[int] = None, # Number of times a deployment can failbefore being added to cooldown
|
allowed_fails: Optional[int] = None, # Number of times a deployment can failbefore being added to cooldown
|
||||||
cooldown_time: float = 1, # (seconds) time to cooldown a deployment after failure
|
cooldown_time: float = 1, # (seconds) time to cooldown a deployment after failure
|
||||||
context_window_fallbacks: List = [],
|
context_window_fallbacks: Optional[List] = None,
|
||||||
model_group_alias: Optional[dict] = {},
|
model_group_alias: Optional[dict] = {},
|
||||||
retry_after: int = 0, # (min) time to wait before retrying a failed request
|
retry_after: int = 0, # (min) time to wait before retrying a failed request
|
||||||
routing_strategy: Literal[
|
routing_strategy: Literal[
|
||||||
|
@ -1108,6 +1326,7 @@ def __init__(
|
||||||
"least-busy",
|
"least-busy",
|
||||||
"usage-based-routing",
|
"usage-based-routing",
|
||||||
"latency-based-routing",
|
"latency-based-routing",
|
||||||
|
"cost-based-routing",
|
||||||
] = "simple-shuffle",
|
] = "simple-shuffle",
|
||||||
|
|
||||||
## DEBUGGING ##
|
## DEBUGGING ##
|
||||||
|
|
|
@ -39,6 +39,7 @@ const sidebars = {
|
||||||
"proxy/demo",
|
"proxy/demo",
|
||||||
"proxy/configs",
|
"proxy/configs",
|
||||||
"proxy/reliability",
|
"proxy/reliability",
|
||||||
|
"proxy/cost_tracking",
|
||||||
"proxy/users",
|
"proxy/users",
|
||||||
"proxy/user_keys",
|
"proxy/user_keys",
|
||||||
"proxy/enterprise",
|
"proxy/enterprise",
|
||||||
|
@ -50,8 +51,8 @@ const sidebars = {
|
||||||
items: ["proxy/logging", "proxy/streaming_logging"],
|
items: ["proxy/logging", "proxy/streaming_logging"],
|
||||||
},
|
},
|
||||||
"proxy/team_based_routing",
|
"proxy/team_based_routing",
|
||||||
|
"proxy/customer_routing",
|
||||||
"proxy/ui",
|
"proxy/ui",
|
||||||
"proxy/cost_tracking",
|
|
||||||
"proxy/token_auth",
|
"proxy/token_auth",
|
||||||
{
|
{
|
||||||
type: "category",
|
type: "category",
|
||||||
|
@ -131,9 +132,13 @@ const sidebars = {
|
||||||
"providers/cohere",
|
"providers/cohere",
|
||||||
"providers/anyscale",
|
"providers/anyscale",
|
||||||
"providers/huggingface",
|
"providers/huggingface",
|
||||||
|
"providers/watsonx",
|
||||||
|
"providers/predibase",
|
||||||
|
"providers/triton-inference-server",
|
||||||
"providers/ollama",
|
"providers/ollama",
|
||||||
"providers/perplexity",
|
"providers/perplexity",
|
||||||
"providers/groq",
|
"providers/groq",
|
||||||
|
"providers/deepseek",
|
||||||
"providers/fireworks_ai",
|
"providers/fireworks_ai",
|
||||||
"providers/vllm",
|
"providers/vllm",
|
||||||
"providers/xinference",
|
"providers/xinference",
|
||||||
|
@ -149,7 +154,7 @@ const sidebars = {
|
||||||
"providers/openrouter",
|
"providers/openrouter",
|
||||||
"providers/custom_openai_proxy",
|
"providers/custom_openai_proxy",
|
||||||
"providers/petals",
|
"providers/petals",
|
||||||
"providers/watsonx",
|
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
"proxy/custom_pricing",
|
"proxy/custom_pricing",
|
||||||
|
|
|
@ -10,7 +10,6 @@ from litellm.caching import DualCache
|
||||||
|
|
||||||
from typing import Literal, Union
|
from typing import Literal, Union
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,8 +18,6 @@ import traceback
|
||||||
|
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import traceback
|
import traceback
|
||||||
import datetime, subprocess, sys
|
import datetime, subprocess, sys
|
||||||
import litellm, uuid
|
import litellm, uuid
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
# Enterprise Proxy Util Endpoints
|
# Enterprise Proxy Util Endpoints
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
import collections
|
import collections
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
|
async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
|
||||||
|
@ -18,26 +19,33 @@ async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
async def ui_get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
|
async def ui_get_spend_by_tags(start_date: str, end_date: str, prisma_client):
|
||||||
response = await prisma_client.db.query_raw(
|
|
||||||
"""
|
sql_query = """
|
||||||
SELECT
|
SELECT
|
||||||
jsonb_array_elements_text(request_tags) AS individual_request_tag,
|
jsonb_array_elements_text(request_tags) AS individual_request_tag,
|
||||||
DATE(s."startTime") AS spend_date,
|
DATE(s."startTime") AS spend_date,
|
||||||
COUNT(*) AS log_count,
|
COUNT(*) AS log_count,
|
||||||
SUM(spend) AS total_spend
|
SUM(spend) AS total_spend
|
||||||
FROM "LiteLLM_SpendLogs" s
|
FROM "LiteLLM_SpendLogs" s
|
||||||
WHERE s."startTime" >= current_date - interval '30 days'
|
WHERE
|
||||||
|
DATE(s."startTime") >= $1::date
|
||||||
|
AND DATE(s."startTime") <= $2::date
|
||||||
GROUP BY individual_request_tag, spend_date
|
GROUP BY individual_request_tag, spend_date
|
||||||
ORDER BY spend_date;
|
ORDER BY spend_date
|
||||||
|
LIMIT 100;
|
||||||
"""
|
"""
|
||||||
|
response = await prisma_client.db.query_raw(
|
||||||
|
sql_query,
|
||||||
|
start_date,
|
||||||
|
end_date,
|
||||||
)
|
)
|
||||||
|
|
||||||
# print("tags - spend")
|
# print("tags - spend")
|
||||||
# print(response)
|
# print(response)
|
||||||
# Bar Chart 1 - Spend per tag - Top 10 tags by spend
|
# Bar Chart 1 - Spend per tag - Top 10 tags by spend
|
||||||
total_spend_per_tag = collections.defaultdict(float)
|
total_spend_per_tag: collections.defaultdict = collections.defaultdict(float)
|
||||||
total_requests_per_tag = collections.defaultdict(int)
|
total_requests_per_tag: collections.defaultdict = collections.defaultdict(int)
|
||||||
for row in response:
|
for row in response:
|
||||||
tag_name = row["individual_request_tag"]
|
tag_name = row["individual_request_tag"]
|
||||||
tag_spend = row["total_spend"]
|
tag_spend = row["total_spend"]
|
||||||
|
@ -49,15 +57,18 @@ async def ui_get_spend_by_tags(start_date=None, end_date=None, prisma_client=Non
|
||||||
# convert to ui format
|
# convert to ui format
|
||||||
ui_tags = []
|
ui_tags = []
|
||||||
for tag in sorted_tags:
|
for tag in sorted_tags:
|
||||||
|
current_spend = tag[1]
|
||||||
|
if current_spend is not None and isinstance(current_spend, float):
|
||||||
|
current_spend = round(current_spend, 4)
|
||||||
ui_tags.append(
|
ui_tags.append(
|
||||||
{
|
{
|
||||||
"name": tag[0],
|
"name": tag[0],
|
||||||
"value": tag[1],
|
"spend": current_spend,
|
||||||
"log_count": total_requests_per_tag[tag[0]],
|
"log_count": total_requests_per_tag[tag[0]],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"top_10_tags": ui_tags}
|
return {"spend_per_tag": ui_tags}
|
||||||
|
|
||||||
|
|
||||||
async def view_spend_logs_from_clickhouse(
|
async def view_spend_logs_from_clickhouse(
|
||||||
|
@ -291,7 +302,7 @@ def _create_clickhouse_aggregate_tables(client=None, table_names=[]):
|
||||||
|
|
||||||
|
|
||||||
def _forecast_daily_cost(data: list):
|
def _forecast_daily_cost(data: list):
|
||||||
import requests
|
import requests # type: ignore
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
if len(data) == 0:
|
if len(data) == 0:
|
||||||
|
|
108
index.yaml
Normal file
108
index.yaml
Normal file
|
@ -0,0 +1,108 @@
|
||||||
|
apiVersion: v1
|
||||||
|
entries:
|
||||||
|
litellm-helm:
|
||||||
|
- apiVersion: v2
|
||||||
|
appVersion: v1.35.38
|
||||||
|
created: "2024-05-06T10:22:24.384392-07:00"
|
||||||
|
dependencies:
|
||||||
|
- condition: db.deployStandalone
|
||||||
|
name: postgresql
|
||||||
|
repository: oci://registry-1.docker.io/bitnamicharts
|
||||||
|
version: '>=13.3.0'
|
||||||
|
- condition: redis.enabled
|
||||||
|
name: redis
|
||||||
|
repository: oci://registry-1.docker.io/bitnamicharts
|
||||||
|
version: '>=18.0.0'
|
||||||
|
description: Call all LLM APIs using the OpenAI format
|
||||||
|
digest: 60f0cfe9e7c1087437cb35f6fb7c43c3ab2be557b6d3aec8295381eb0dfa760f
|
||||||
|
name: litellm-helm
|
||||||
|
type: application
|
||||||
|
urls:
|
||||||
|
- litellm-helm-0.2.0.tgz
|
||||||
|
version: 0.2.0
|
||||||
|
postgresql:
|
||||||
|
- annotations:
|
||||||
|
category: Database
|
||||||
|
images: |
|
||||||
|
- name: os-shell
|
||||||
|
image: docker.io/bitnami/os-shell:12-debian-12-r16
|
||||||
|
- name: postgres-exporter
|
||||||
|
image: docker.io/bitnami/postgres-exporter:0.15.0-debian-12-r14
|
||||||
|
- name: postgresql
|
||||||
|
image: docker.io/bitnami/postgresql:16.2.0-debian-12-r6
|
||||||
|
licenses: Apache-2.0
|
||||||
|
apiVersion: v2
|
||||||
|
appVersion: 16.2.0
|
||||||
|
created: "2024-05-06T10:22:24.387717-07:00"
|
||||||
|
dependencies:
|
||||||
|
- name: common
|
||||||
|
repository: oci://registry-1.docker.io/bitnamicharts
|
||||||
|
tags:
|
||||||
|
- bitnami-common
|
||||||
|
version: 2.x.x
|
||||||
|
description: PostgreSQL (Postgres) is an open source object-relational database
|
||||||
|
known for reliability and data integrity. ACID-compliant, it supports foreign
|
||||||
|
keys, joins, views, triggers and stored procedures.
|
||||||
|
digest: 3c8125526b06833df32e2f626db34aeaedb29d38f03d15349db6604027d4a167
|
||||||
|
home: https://bitnami.com
|
||||||
|
icon: https://bitnami.com/assets/stacks/postgresql/img/postgresql-stack-220x234.png
|
||||||
|
keywords:
|
||||||
|
- postgresql
|
||||||
|
- postgres
|
||||||
|
- database
|
||||||
|
- sql
|
||||||
|
- replication
|
||||||
|
- cluster
|
||||||
|
maintainers:
|
||||||
|
- name: VMware, Inc.
|
||||||
|
url: https://github.com/bitnami/charts
|
||||||
|
name: postgresql
|
||||||
|
sources:
|
||||||
|
- https://github.com/bitnami/charts/tree/main/bitnami/postgresql
|
||||||
|
urls:
|
||||||
|
- charts/postgresql-14.3.1.tgz
|
||||||
|
version: 14.3.1
|
||||||
|
redis:
|
||||||
|
- annotations:
|
||||||
|
category: Database
|
||||||
|
images: |
|
||||||
|
- name: kubectl
|
||||||
|
image: docker.io/bitnami/kubectl:1.29.2-debian-12-r3
|
||||||
|
- name: os-shell
|
||||||
|
image: docker.io/bitnami/os-shell:12-debian-12-r16
|
||||||
|
- name: redis
|
||||||
|
image: docker.io/bitnami/redis:7.2.4-debian-12-r9
|
||||||
|
- name: redis-exporter
|
||||||
|
image: docker.io/bitnami/redis-exporter:1.58.0-debian-12-r4
|
||||||
|
- name: redis-sentinel
|
||||||
|
image: docker.io/bitnami/redis-sentinel:7.2.4-debian-12-r7
|
||||||
|
licenses: Apache-2.0
|
||||||
|
apiVersion: v2
|
||||||
|
appVersion: 7.2.4
|
||||||
|
created: "2024-05-06T10:22:24.391903-07:00"
|
||||||
|
dependencies:
|
||||||
|
- name: common
|
||||||
|
repository: oci://registry-1.docker.io/bitnamicharts
|
||||||
|
tags:
|
||||||
|
- bitnami-common
|
||||||
|
version: 2.x.x
|
||||||
|
description: Redis(R) is an open source, advanced key-value store. It is often
|
||||||
|
referred to as a data structure server since keys can contain strings, hashes,
|
||||||
|
lists, sets and sorted sets.
|
||||||
|
digest: b2fa1835f673a18002ca864c54fadac3c33789b26f6c5e58e2851b0b14a8f984
|
||||||
|
home: https://bitnami.com
|
||||||
|
icon: https://bitnami.com/assets/stacks/redis/img/redis-stack-220x234.png
|
||||||
|
keywords:
|
||||||
|
- redis
|
||||||
|
- keyvalue
|
||||||
|
- database
|
||||||
|
maintainers:
|
||||||
|
- name: VMware, Inc.
|
||||||
|
url: https://github.com/bitnami/charts
|
||||||
|
name: redis
|
||||||
|
sources:
|
||||||
|
- https://github.com/bitnami/charts/tree/main/bitnami/redis
|
||||||
|
urls:
|
||||||
|
- charts/redis-18.19.1.tgz
|
||||||
|
version: 18.19.1
|
||||||
|
generated: "2024-05-06T10:22:24.375026-07:00"
|
BIN
litellm-helm-0.2.0.tgz
Normal file
BIN
litellm-helm-0.2.0.tgz
Normal file
Binary file not shown.
|
@ -1,3 +1,7 @@
|
||||||
|
### Hide pydantic namespace conflict warnings globally ###
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
|
||||||
### INIT VARIABLES ###
|
### INIT VARIABLES ###
|
||||||
import threading, requests, os
|
import threading, requests, os
|
||||||
from typing import Callable, List, Optional, Dict, Union, Any, Literal
|
from typing import Callable, List, Optional, Dict, Union, Any, Literal
|
||||||
|
@ -67,13 +71,16 @@ azure_key: Optional[str] = None
|
||||||
anthropic_key: Optional[str] = None
|
anthropic_key: Optional[str] = None
|
||||||
replicate_key: Optional[str] = None
|
replicate_key: Optional[str] = None
|
||||||
cohere_key: Optional[str] = None
|
cohere_key: Optional[str] = None
|
||||||
|
clarifai_key: Optional[str] = None
|
||||||
maritalk_key: Optional[str] = None
|
maritalk_key: Optional[str] = None
|
||||||
ai21_key: Optional[str] = None
|
ai21_key: Optional[str] = None
|
||||||
ollama_key: Optional[str] = None
|
ollama_key: Optional[str] = None
|
||||||
openrouter_key: Optional[str] = None
|
openrouter_key: Optional[str] = None
|
||||||
|
predibase_key: Optional[str] = None
|
||||||
huggingface_key: Optional[str] = None
|
huggingface_key: Optional[str] = None
|
||||||
vertex_project: Optional[str] = None
|
vertex_project: Optional[str] = None
|
||||||
vertex_location: Optional[str] = None
|
vertex_location: Optional[str] = None
|
||||||
|
predibase_tenant_id: Optional[str] = None
|
||||||
togetherai_api_key: Optional[str] = None
|
togetherai_api_key: Optional[str] = None
|
||||||
cloudflare_api_key: Optional[str] = None
|
cloudflare_api_key: Optional[str] = None
|
||||||
baseten_key: Optional[str] = None
|
baseten_key: Optional[str] = None
|
||||||
|
@ -95,6 +102,9 @@ blocked_user_list: Optional[Union[str, List]] = None
|
||||||
banned_keywords_list: Optional[Union[str, List]] = None
|
banned_keywords_list: Optional[Union[str, List]] = None
|
||||||
llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
|
llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
|
||||||
##################
|
##################
|
||||||
|
### PREVIEW FEATURES ###
|
||||||
|
enable_preview_features: bool = False
|
||||||
|
##################
|
||||||
logging: bool = True
|
logging: bool = True
|
||||||
caching: bool = (
|
caching: bool = (
|
||||||
False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||||
|
@ -361,6 +371,7 @@ openai_compatible_endpoints: List = [
|
||||||
"api.deepinfra.com/v1/openai",
|
"api.deepinfra.com/v1/openai",
|
||||||
"api.mistral.ai/v1",
|
"api.mistral.ai/v1",
|
||||||
"api.groq.com/openai/v1",
|
"api.groq.com/openai/v1",
|
||||||
|
"api.deepseek.com/v1",
|
||||||
"api.together.xyz/v1",
|
"api.together.xyz/v1",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -369,6 +380,7 @@ openai_compatible_providers: List = [
|
||||||
"anyscale",
|
"anyscale",
|
||||||
"mistral",
|
"mistral",
|
||||||
"groq",
|
"groq",
|
||||||
|
"deepseek",
|
||||||
"deepinfra",
|
"deepinfra",
|
||||||
"perplexity",
|
"perplexity",
|
||||||
"xinference",
|
"xinference",
|
||||||
|
@ -393,6 +405,73 @@ replicate_models: List = [
|
||||||
"replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad",
|
"replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
clarifai_models: List = [
|
||||||
|
"clarifai/meta.Llama-3.Llama-3-8B-Instruct",
|
||||||
|
"clarifai/gcp.generate.gemma-1_1-7b-it",
|
||||||
|
"clarifai/mistralai.completion.mixtral-8x22B",
|
||||||
|
"clarifai/cohere.generate.command-r-plus",
|
||||||
|
"clarifai/databricks.drbx.dbrx-instruct",
|
||||||
|
"clarifai/mistralai.completion.mistral-large",
|
||||||
|
"clarifai/mistralai.completion.mistral-medium",
|
||||||
|
"clarifai/mistralai.completion.mistral-small",
|
||||||
|
"clarifai/mistralai.completion.mixtral-8x7B-Instruct-v0_1",
|
||||||
|
"clarifai/gcp.generate.gemma-2b-it",
|
||||||
|
"clarifai/gcp.generate.gemma-7b-it",
|
||||||
|
"clarifai/deci.decilm.deciLM-7B-instruct",
|
||||||
|
"clarifai/mistralai.completion.mistral-7B-Instruct",
|
||||||
|
"clarifai/gcp.generate.gemini-pro",
|
||||||
|
"clarifai/anthropic.completion.claude-v1",
|
||||||
|
"clarifai/anthropic.completion.claude-instant-1_2",
|
||||||
|
"clarifai/anthropic.completion.claude-instant",
|
||||||
|
"clarifai/anthropic.completion.claude-v2",
|
||||||
|
"clarifai/anthropic.completion.claude-2_1",
|
||||||
|
"clarifai/meta.Llama-2.codeLlama-70b-Python",
|
||||||
|
"clarifai/meta.Llama-2.codeLlama-70b-Instruct",
|
||||||
|
"clarifai/openai.completion.gpt-3_5-turbo-instruct",
|
||||||
|
"clarifai/meta.Llama-2.llama2-7b-chat",
|
||||||
|
"clarifai/meta.Llama-2.llama2-13b-chat",
|
||||||
|
"clarifai/meta.Llama-2.llama2-70b-chat",
|
||||||
|
"clarifai/openai.chat-completion.gpt-4-turbo",
|
||||||
|
"clarifai/microsoft.text-generation.phi-2",
|
||||||
|
"clarifai/meta.Llama-2.llama2-7b-chat-vllm",
|
||||||
|
"clarifai/upstage.solar.solar-10_7b-instruct",
|
||||||
|
"clarifai/openchat.openchat.openchat-3_5-1210",
|
||||||
|
"clarifai/togethercomputer.stripedHyena.stripedHyena-Nous-7B",
|
||||||
|
"clarifai/gcp.generate.text-bison",
|
||||||
|
"clarifai/meta.Llama-2.llamaGuard-7b",
|
||||||
|
"clarifai/fblgit.una-cybertron.una-cybertron-7b-v2",
|
||||||
|
"clarifai/openai.chat-completion.GPT-4",
|
||||||
|
"clarifai/openai.chat-completion.GPT-3_5-turbo",
|
||||||
|
"clarifai/ai21.complete.Jurassic2-Grande",
|
||||||
|
"clarifai/ai21.complete.Jurassic2-Grande-Instruct",
|
||||||
|
"clarifai/ai21.complete.Jurassic2-Jumbo-Instruct",
|
||||||
|
"clarifai/ai21.complete.Jurassic2-Jumbo",
|
||||||
|
"clarifai/ai21.complete.Jurassic2-Large",
|
||||||
|
"clarifai/cohere.generate.cohere-generate-command",
|
||||||
|
"clarifai/wizardlm.generate.wizardCoder-Python-34B",
|
||||||
|
"clarifai/wizardlm.generate.wizardLM-70B",
|
||||||
|
"clarifai/tiiuae.falcon.falcon-40b-instruct",
|
||||||
|
"clarifai/togethercomputer.RedPajama.RedPajama-INCITE-7B-Chat",
|
||||||
|
"clarifai/gcp.generate.code-gecko",
|
||||||
|
"clarifai/gcp.generate.code-bison",
|
||||||
|
"clarifai/mistralai.completion.mistral-7B-OpenOrca",
|
||||||
|
"clarifai/mistralai.completion.openHermes-2-mistral-7B",
|
||||||
|
"clarifai/wizardlm.generate.wizardLM-13B",
|
||||||
|
"clarifai/huggingface-research.zephyr.zephyr-7B-alpha",
|
||||||
|
"clarifai/wizardlm.generate.wizardCoder-15B",
|
||||||
|
"clarifai/microsoft.text-generation.phi-1_5",
|
||||||
|
"clarifai/databricks.Dolly-v2.dolly-v2-12b",
|
||||||
|
"clarifai/bigcode.code.StarCoder",
|
||||||
|
"clarifai/salesforce.xgen.xgen-7b-8k-instruct",
|
||||||
|
"clarifai/mosaicml.mpt.mpt-7b-instruct",
|
||||||
|
"clarifai/anthropic.completion.claude-3-opus",
|
||||||
|
"clarifai/anthropic.completion.claude-3-sonnet",
|
||||||
|
"clarifai/gcp.generate.gemini-1_5-pro",
|
||||||
|
"clarifai/gcp.generate.imagen-2",
|
||||||
|
"clarifai/salesforce.blip.general-english-image-caption-blip-2",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
huggingface_models: List = [
|
huggingface_models: List = [
|
||||||
"meta-llama/Llama-2-7b-hf",
|
"meta-llama/Llama-2-7b-hf",
|
||||||
"meta-llama/Llama-2-7b-chat-hf",
|
"meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
@ -498,6 +577,7 @@ provider_list: List = [
|
||||||
"text-completion-openai",
|
"text-completion-openai",
|
||||||
"cohere",
|
"cohere",
|
||||||
"cohere_chat",
|
"cohere_chat",
|
||||||
|
"clarifai",
|
||||||
"anthropic",
|
"anthropic",
|
||||||
"replicate",
|
"replicate",
|
||||||
"huggingface",
|
"huggingface",
|
||||||
|
@ -523,12 +603,15 @@ provider_list: List = [
|
||||||
"anyscale",
|
"anyscale",
|
||||||
"mistral",
|
"mistral",
|
||||||
"groq",
|
"groq",
|
||||||
|
"deepseek",
|
||||||
"maritalk",
|
"maritalk",
|
||||||
"voyage",
|
"voyage",
|
||||||
"cloudflare",
|
"cloudflare",
|
||||||
"xinference",
|
"xinference",
|
||||||
"fireworks_ai",
|
"fireworks_ai",
|
||||||
"watsonx",
|
"watsonx",
|
||||||
|
"triton",
|
||||||
|
"predibase",
|
||||||
"custom", # custom apis
|
"custom", # custom apis
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -605,7 +688,6 @@ all_embedding_models = (
|
||||||
####### IMAGE GENERATION MODELS ###################
|
####### IMAGE GENERATION MODELS ###################
|
||||||
openai_image_generation_models = ["dall-e-2", "dall-e-3"]
|
openai_image_generation_models = ["dall-e-2", "dall-e-3"]
|
||||||
|
|
||||||
|
|
||||||
from .timeout import timeout
|
from .timeout import timeout
|
||||||
from .utils import (
|
from .utils import (
|
||||||
client,
|
client,
|
||||||
|
@ -638,12 +720,15 @@ from .utils import (
|
||||||
get_secret,
|
get_secret,
|
||||||
get_supported_openai_params,
|
get_supported_openai_params,
|
||||||
get_api_base,
|
get_api_base,
|
||||||
|
get_first_chars_messages,
|
||||||
)
|
)
|
||||||
from .llms.huggingface_restapi import HuggingfaceConfig
|
from .llms.huggingface_restapi import HuggingfaceConfig
|
||||||
from .llms.anthropic import AnthropicConfig
|
from .llms.anthropic import AnthropicConfig
|
||||||
|
from .llms.predibase import PredibaseConfig
|
||||||
from .llms.anthropic_text import AnthropicTextConfig
|
from .llms.anthropic_text import AnthropicTextConfig
|
||||||
from .llms.replicate import ReplicateConfig
|
from .llms.replicate import ReplicateConfig
|
||||||
from .llms.cohere import CohereConfig
|
from .llms.cohere import CohereConfig
|
||||||
|
from .llms.clarifai import ClarifaiConfig
|
||||||
from .llms.ai21 import AI21Config
|
from .llms.ai21 import AI21Config
|
||||||
from .llms.together_ai import TogetherAIConfig
|
from .llms.together_ai import TogetherAIConfig
|
||||||
from .llms.cloudflare import CloudflareConfig
|
from .llms.cloudflare import CloudflareConfig
|
||||||
|
@ -658,6 +743,7 @@ from .llms.sagemaker import SagemakerConfig
|
||||||
from .llms.ollama import OllamaConfig
|
from .llms.ollama import OllamaConfig
|
||||||
from .llms.ollama_chat import OllamaChatConfig
|
from .llms.ollama_chat import OllamaChatConfig
|
||||||
from .llms.maritalk import MaritTalkConfig
|
from .llms.maritalk import MaritTalkConfig
|
||||||
|
from .llms.bedrock_httpx import AmazonCohereChatConfig
|
||||||
from .llms.bedrock import (
|
from .llms.bedrock import (
|
||||||
AmazonTitanConfig,
|
AmazonTitanConfig,
|
||||||
AmazonAI21Config,
|
AmazonAI21Config,
|
||||||
|
@ -669,7 +755,7 @@ from .llms.bedrock import (
|
||||||
AmazonMistralConfig,
|
AmazonMistralConfig,
|
||||||
AmazonBedrockGlobalConfig,
|
AmazonBedrockGlobalConfig,
|
||||||
)
|
)
|
||||||
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
|
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig, MistralConfig
|
||||||
from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
|
from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
|
||||||
from .llms.watsonx import IBMWatsonXAIConfig
|
from .llms.watsonx import IBMWatsonXAIConfig
|
||||||
from .main import * # type: ignore
|
from .main import * # type: ignore
|
||||||
|
@ -694,3 +780,4 @@ from .exceptions import (
|
||||||
from .budget_manager import BudgetManager
|
from .budget_manager import BudgetManager
|
||||||
from .proxy.proxy_cli import run_server
|
from .proxy.proxy_cli import run_server
|
||||||
from .router import Router
|
from .router import Router
|
||||||
|
from .assistants.main import *
|
||||||
|
|
|
@ -10,8 +10,8 @@
|
||||||
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
|
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
|
||||||
import os
|
import os
|
||||||
import inspect
|
import inspect
|
||||||
import redis, litellm
|
import redis, litellm # type: ignore
|
||||||
import redis.asyncio as async_redis
|
import redis.asyncio as async_redis # type: ignore
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
|
495
litellm/assistants/main.py
Normal file
495
litellm/assistants/main.py
Normal file
|
@ -0,0 +1,495 @@
|
||||||
|
# What is this?
|
||||||
|
## Main file for assistants API logic
|
||||||
|
from typing import Iterable
|
||||||
|
import os
|
||||||
|
import litellm
|
||||||
|
from openai import OpenAI
|
||||||
|
from litellm import client
|
||||||
|
from litellm.utils import supports_httpx_timeout
|
||||||
|
from ..llms.openai import OpenAIAssistantsAPI
|
||||||
|
from ..types.llms.openai import *
|
||||||
|
from ..types.router import *
|
||||||
|
|
||||||
|
####### ENVIRONMENT VARIABLES ###################
|
||||||
|
openai_assistants_api = OpenAIAssistantsAPI()
|
||||||
|
|
||||||
|
### ASSISTANTS ###
|
||||||
|
|
||||||
|
|
||||||
|
def get_assistants(
|
||||||
|
custom_llm_provider: Literal["openai"],
|
||||||
|
client: Optional[OpenAI] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> SyncCursorPage[Assistant]:
|
||||||
|
optional_params = GenericLiteLLMParams(**kwargs)
|
||||||
|
|
||||||
|
### TIMEOUT LOGIC ###
|
||||||
|
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||||
|
# set timeout for 10 minutes by default
|
||||||
|
|
||||||
|
if (
|
||||||
|
timeout is not None
|
||||||
|
and isinstance(timeout, httpx.Timeout)
|
||||||
|
and supports_httpx_timeout(custom_llm_provider) == False
|
||||||
|
):
|
||||||
|
read_timeout = timeout.read or 600
|
||||||
|
timeout = read_timeout # default 10 min timeout
|
||||||
|
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||||
|
timeout = float(timeout) # type: ignore
|
||||||
|
elif timeout is None:
|
||||||
|
timeout = 600.0
|
||||||
|
|
||||||
|
response: Optional[SyncCursorPage[Assistant]] = None
|
||||||
|
if custom_llm_provider == "openai":
|
||||||
|
api_base = (
|
||||||
|
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||||
|
or litellm.api_base
|
||||||
|
or os.getenv("OPENAI_API_BASE")
|
||||||
|
or "https://api.openai.com/v1"
|
||||||
|
)
|
||||||
|
organization = (
|
||||||
|
optional_params.organization
|
||||||
|
or litellm.organization
|
||||||
|
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||||
|
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||||
|
)
|
||||||
|
# set API KEY
|
||||||
|
api_key = (
|
||||||
|
optional_params.api_key
|
||||||
|
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||||
|
or litellm.openai_key
|
||||||
|
or os.getenv("OPENAI_API_KEY")
|
||||||
|
)
|
||||||
|
response = openai_assistants_api.get_assistants(
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=optional_params.max_retries,
|
||||||
|
organization=organization,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise litellm.exceptions.BadRequestError(
|
||||||
|
message="LiteLLM doesn't support {} for 'get_assistants'. Only 'openai' is supported.".format(
|
||||||
|
custom_llm_provider
|
||||||
|
),
|
||||||
|
model="n/a",
|
||||||
|
llm_provider=custom_llm_provider,
|
||||||
|
response=httpx.Response(
|
||||||
|
status_code=400,
|
||||||
|
content="Unsupported provider",
|
||||||
|
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
### THREADS ###
|
||||||
|
|
||||||
|
|
||||||
|
def create_thread(
|
||||||
|
custom_llm_provider: Literal["openai"],
|
||||||
|
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]] = None,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
|
tool_resources: Optional[OpenAICreateThreadParamsToolResources] = None,
|
||||||
|
client: Optional[OpenAI] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Thread:
|
||||||
|
"""
|
||||||
|
- get the llm provider
|
||||||
|
- if openai - route it there
|
||||||
|
- pass through relevant params
|
||||||
|
|
||||||
|
```
|
||||||
|
from litellm import create_thread
|
||||||
|
|
||||||
|
create_thread(
|
||||||
|
custom_llm_provider="openai",
|
||||||
|
### OPTIONAL ###
|
||||||
|
messages = {
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello, what is AI?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "How does AI work? Explain it in simple terms."
|
||||||
|
}]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
optional_params = GenericLiteLLMParams(**kwargs)
|
||||||
|
|
||||||
|
### TIMEOUT LOGIC ###
|
||||||
|
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||||
|
# set timeout for 10 minutes by default
|
||||||
|
|
||||||
|
if (
|
||||||
|
timeout is not None
|
||||||
|
and isinstance(timeout, httpx.Timeout)
|
||||||
|
and supports_httpx_timeout(custom_llm_provider) == False
|
||||||
|
):
|
||||||
|
read_timeout = timeout.read or 600
|
||||||
|
timeout = read_timeout # default 10 min timeout
|
||||||
|
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||||
|
timeout = float(timeout) # type: ignore
|
||||||
|
elif timeout is None:
|
||||||
|
timeout = 600.0
|
||||||
|
|
||||||
|
response: Optional[Thread] = None
|
||||||
|
if custom_llm_provider == "openai":
|
||||||
|
api_base = (
|
||||||
|
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||||
|
or litellm.api_base
|
||||||
|
or os.getenv("OPENAI_API_BASE")
|
||||||
|
or "https://api.openai.com/v1"
|
||||||
|
)
|
||||||
|
organization = (
|
||||||
|
optional_params.organization
|
||||||
|
or litellm.organization
|
||||||
|
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||||
|
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||||
|
)
|
||||||
|
# set API KEY
|
||||||
|
api_key = (
|
||||||
|
optional_params.api_key
|
||||||
|
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||||
|
or litellm.openai_key
|
||||||
|
or os.getenv("OPENAI_API_KEY")
|
||||||
|
)
|
||||||
|
response = openai_assistants_api.create_thread(
|
||||||
|
messages=messages,
|
||||||
|
metadata=metadata,
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=optional_params.max_retries,
|
||||||
|
organization=organization,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise litellm.exceptions.BadRequestError(
|
||||||
|
message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
|
||||||
|
custom_llm_provider
|
||||||
|
),
|
||||||
|
model="n/a",
|
||||||
|
llm_provider=custom_llm_provider,
|
||||||
|
response=httpx.Response(
|
||||||
|
status_code=400,
|
||||||
|
content="Unsupported provider",
|
||||||
|
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def get_thread(
|
||||||
|
custom_llm_provider: Literal["openai"],
|
||||||
|
thread_id: str,
|
||||||
|
client: Optional[OpenAI] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Thread:
|
||||||
|
"""Get the thread object, given a thread_id"""
|
||||||
|
optional_params = GenericLiteLLMParams(**kwargs)
|
||||||
|
|
||||||
|
### TIMEOUT LOGIC ###
|
||||||
|
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||||
|
# set timeout for 10 minutes by default
|
||||||
|
|
||||||
|
if (
|
||||||
|
timeout is not None
|
||||||
|
and isinstance(timeout, httpx.Timeout)
|
||||||
|
and supports_httpx_timeout(custom_llm_provider) == False
|
||||||
|
):
|
||||||
|
read_timeout = timeout.read or 600
|
||||||
|
timeout = read_timeout # default 10 min timeout
|
||||||
|
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||||
|
timeout = float(timeout) # type: ignore
|
||||||
|
elif timeout is None:
|
||||||
|
timeout = 600.0
|
||||||
|
|
||||||
|
response: Optional[Thread] = None
|
||||||
|
if custom_llm_provider == "openai":
|
||||||
|
api_base = (
|
||||||
|
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||||
|
or litellm.api_base
|
||||||
|
or os.getenv("OPENAI_API_BASE")
|
||||||
|
or "https://api.openai.com/v1"
|
||||||
|
)
|
||||||
|
organization = (
|
||||||
|
optional_params.organization
|
||||||
|
or litellm.organization
|
||||||
|
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||||
|
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||||
|
)
|
||||||
|
# set API KEY
|
||||||
|
api_key = (
|
||||||
|
optional_params.api_key
|
||||||
|
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||||
|
or litellm.openai_key
|
||||||
|
or os.getenv("OPENAI_API_KEY")
|
||||||
|
)
|
||||||
|
response = openai_assistants_api.get_thread(
|
||||||
|
thread_id=thread_id,
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=optional_params.max_retries,
|
||||||
|
organization=organization,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise litellm.exceptions.BadRequestError(
|
||||||
|
message="LiteLLM doesn't support {} for 'get_thread'. Only 'openai' is supported.".format(
|
||||||
|
custom_llm_provider
|
||||||
|
),
|
||||||
|
model="n/a",
|
||||||
|
llm_provider=custom_llm_provider,
|
||||||
|
response=httpx.Response(
|
||||||
|
status_code=400,
|
||||||
|
content="Unsupported provider",
|
||||||
|
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
### MESSAGES ###
|
||||||
|
|
||||||
|
|
||||||
|
def add_message(
|
||||||
|
custom_llm_provider: Literal["openai"],
|
||||||
|
thread_id: str,
|
||||||
|
role: Literal["user", "assistant"],
|
||||||
|
content: str,
|
||||||
|
attachments: Optional[List[Attachment]] = None,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
|
client: Optional[OpenAI] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> OpenAIMessage:
|
||||||
|
### COMMON OBJECTS ###
|
||||||
|
message_data = MessageData(
|
||||||
|
role=role, content=content, attachments=attachments, metadata=metadata
|
||||||
|
)
|
||||||
|
optional_params = GenericLiteLLMParams(**kwargs)
|
||||||
|
|
||||||
|
### TIMEOUT LOGIC ###
|
||||||
|
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||||
|
# set timeout for 10 minutes by default
|
||||||
|
|
||||||
|
if (
|
||||||
|
timeout is not None
|
||||||
|
and isinstance(timeout, httpx.Timeout)
|
||||||
|
and supports_httpx_timeout(custom_llm_provider) == False
|
||||||
|
):
|
||||||
|
read_timeout = timeout.read or 600
|
||||||
|
timeout = read_timeout # default 10 min timeout
|
||||||
|
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||||
|
timeout = float(timeout) # type: ignore
|
||||||
|
elif timeout is None:
|
||||||
|
timeout = 600.0
|
||||||
|
|
||||||
|
response: Optional[OpenAIMessage] = None
|
||||||
|
if custom_llm_provider == "openai":
|
||||||
|
api_base = (
|
||||||
|
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||||
|
or litellm.api_base
|
||||||
|
or os.getenv("OPENAI_API_BASE")
|
||||||
|
or "https://api.openai.com/v1"
|
||||||
|
)
|
||||||
|
organization = (
|
||||||
|
optional_params.organization
|
||||||
|
or litellm.organization
|
||||||
|
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||||
|
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||||
|
)
|
||||||
|
# set API KEY
|
||||||
|
api_key = (
|
||||||
|
optional_params.api_key
|
||||||
|
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||||
|
or litellm.openai_key
|
||||||
|
or os.getenv("OPENAI_API_KEY")
|
||||||
|
)
|
||||||
|
response = openai_assistants_api.add_message(
|
||||||
|
thread_id=thread_id,
|
||||||
|
message_data=message_data,
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=optional_params.max_retries,
|
||||||
|
organization=organization,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise litellm.exceptions.BadRequestError(
|
||||||
|
message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
|
||||||
|
custom_llm_provider
|
||||||
|
),
|
||||||
|
model="n/a",
|
||||||
|
llm_provider=custom_llm_provider,
|
||||||
|
response=httpx.Response(
|
||||||
|
status_code=400,
|
||||||
|
content="Unsupported provider",
|
||||||
|
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def get_messages(
|
||||||
|
custom_llm_provider: Literal["openai"],
|
||||||
|
thread_id: str,
|
||||||
|
client: Optional[OpenAI] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> SyncCursorPage[OpenAIMessage]:
|
||||||
|
optional_params = GenericLiteLLMParams(**kwargs)
|
||||||
|
|
||||||
|
### TIMEOUT LOGIC ###
|
||||||
|
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||||
|
# set timeout for 10 minutes by default
|
||||||
|
|
||||||
|
if (
|
||||||
|
timeout is not None
|
||||||
|
and isinstance(timeout, httpx.Timeout)
|
||||||
|
and supports_httpx_timeout(custom_llm_provider) == False
|
||||||
|
):
|
||||||
|
read_timeout = timeout.read or 600
|
||||||
|
timeout = read_timeout # default 10 min timeout
|
||||||
|
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||||
|
timeout = float(timeout) # type: ignore
|
||||||
|
elif timeout is None:
|
||||||
|
timeout = 600.0
|
||||||
|
|
||||||
|
response: Optional[SyncCursorPage[OpenAIMessage]] = None
|
||||||
|
if custom_llm_provider == "openai":
|
||||||
|
api_base = (
|
||||||
|
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||||
|
or litellm.api_base
|
||||||
|
or os.getenv("OPENAI_API_BASE")
|
||||||
|
or "https://api.openai.com/v1"
|
||||||
|
)
|
||||||
|
organization = (
|
||||||
|
optional_params.organization
|
||||||
|
or litellm.organization
|
||||||
|
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||||
|
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||||
|
)
|
||||||
|
# set API KEY
|
||||||
|
api_key = (
|
||||||
|
optional_params.api_key
|
||||||
|
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||||
|
or litellm.openai_key
|
||||||
|
or os.getenv("OPENAI_API_KEY")
|
||||||
|
)
|
||||||
|
response = openai_assistants_api.get_messages(
|
||||||
|
thread_id=thread_id,
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=optional_params.max_retries,
|
||||||
|
organization=organization,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise litellm.exceptions.BadRequestError(
|
||||||
|
message="LiteLLM doesn't support {} for 'get_messages'. Only 'openai' is supported.".format(
|
||||||
|
custom_llm_provider
|
||||||
|
),
|
||||||
|
model="n/a",
|
||||||
|
llm_provider=custom_llm_provider,
|
||||||
|
response=httpx.Response(
|
||||||
|
status_code=400,
|
||||||
|
content="Unsupported provider",
|
||||||
|
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
### RUNS ###
|
||||||
|
|
||||||
|
|
||||||
|
def run_thread(
|
||||||
|
custom_llm_provider: Literal["openai"],
|
||||||
|
thread_id: str,
|
||||||
|
assistant_id: str,
|
||||||
|
additional_instructions: Optional[str] = None,
|
||||||
|
instructions: Optional[str] = None,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
stream: Optional[bool] = None,
|
||||||
|
tools: Optional[Iterable[AssistantToolParam]] = None,
|
||||||
|
client: Optional[OpenAI] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Run:
|
||||||
|
"""Run a given thread + assistant."""
|
||||||
|
optional_params = GenericLiteLLMParams(**kwargs)
|
||||||
|
|
||||||
|
### TIMEOUT LOGIC ###
|
||||||
|
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||||
|
# set timeout for 10 minutes by default
|
||||||
|
|
||||||
|
if (
|
||||||
|
timeout is not None
|
||||||
|
and isinstance(timeout, httpx.Timeout)
|
||||||
|
and supports_httpx_timeout(custom_llm_provider) == False
|
||||||
|
):
|
||||||
|
read_timeout = timeout.read or 600
|
||||||
|
timeout = read_timeout # default 10 min timeout
|
||||||
|
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||||
|
timeout = float(timeout) # type: ignore
|
||||||
|
elif timeout is None:
|
||||||
|
timeout = 600.0
|
||||||
|
|
||||||
|
response: Optional[Run] = None
|
||||||
|
if custom_llm_provider == "openai":
|
||||||
|
api_base = (
|
||||||
|
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||||
|
or litellm.api_base
|
||||||
|
or os.getenv("OPENAI_API_BASE")
|
||||||
|
or "https://api.openai.com/v1"
|
||||||
|
)
|
||||||
|
organization = (
|
||||||
|
optional_params.organization
|
||||||
|
or litellm.organization
|
||||||
|
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||||
|
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||||
|
)
|
||||||
|
# set API KEY
|
||||||
|
api_key = (
|
||||||
|
optional_params.api_key
|
||||||
|
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||||
|
or litellm.openai_key
|
||||||
|
or os.getenv("OPENAI_API_KEY")
|
||||||
|
)
|
||||||
|
response = openai_assistants_api.run_thread(
|
||||||
|
thread_id=thread_id,
|
||||||
|
assistant_id=assistant_id,
|
||||||
|
additional_instructions=additional_instructions,
|
||||||
|
instructions=instructions,
|
||||||
|
metadata=metadata,
|
||||||
|
model=model,
|
||||||
|
stream=stream,
|
||||||
|
tools=tools,
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=optional_params.max_retries,
|
||||||
|
organization=organization,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise litellm.exceptions.BadRequestError(
|
||||||
|
message="LiteLLM doesn't support {} for 'run_thread'. Only 'openai' is supported.".format(
|
||||||
|
custom_llm_provider
|
||||||
|
),
|
||||||
|
model="n/a",
|
||||||
|
llm_provider=custom_llm_provider,
|
||||||
|
response=httpx.Response(
|
||||||
|
status_code=400,
|
||||||
|
content="Unsupported provider",
|
||||||
|
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return response
|
|
@ -10,7 +10,7 @@
|
||||||
import os, json, time
|
import os, json, time
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.utils import ModelResponse
|
from litellm.utils import ModelResponse
|
||||||
import requests, threading
|
import requests, threading # type: ignore
|
||||||
from typing import Optional, Union, Literal
|
from typing import Optional, Union, Literal
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -106,7 +106,7 @@ class InMemoryCache(BaseCache):
|
||||||
return_val.append(val)
|
return_val.append(val)
|
||||||
return return_val
|
return return_val
|
||||||
|
|
||||||
async def async_increment(self, key, value: int, **kwargs) -> int:
|
async def async_increment(self, key, value: float, **kwargs) -> float:
|
||||||
# get the value
|
# get the value
|
||||||
init_value = await self.async_get_cache(key=key) or 0
|
init_value = await self.async_get_cache(key=key) or 0
|
||||||
value = init_value + value
|
value = init_value + value
|
||||||
|
@ -177,11 +177,18 @@ class RedisCache(BaseCache):
|
||||||
try:
|
try:
|
||||||
# asyncio.get_running_loop().create_task(self.ping())
|
# asyncio.get_running_loop().create_task(self.ping())
|
||||||
result = asyncio.get_running_loop().create_task(self.ping())
|
result = asyncio.get_running_loop().create_task(self.ping())
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
verbose_logger.error(
|
||||||
|
"Error connecting to Async Redis client", extra={"error": str(e)}
|
||||||
|
)
|
||||||
|
|
||||||
### SYNC HEALTH PING ###
|
### SYNC HEALTH PING ###
|
||||||
|
try:
|
||||||
self.redis_client.ping()
|
self.redis_client.ping()
|
||||||
|
except Exception as e:
|
||||||
|
verbose_logger.error(
|
||||||
|
"Error connecting to Sync Redis client", extra={"error": str(e)}
|
||||||
|
)
|
||||||
|
|
||||||
def init_async_client(self):
|
def init_async_client(self):
|
||||||
from ._redis import get_redis_async_client
|
from ._redis import get_redis_async_client
|
||||||
|
@ -366,11 +373,12 @@ class RedisCache(BaseCache):
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
|
f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
|
||||||
)
|
)
|
||||||
|
json_cache_value = json.dumps(cache_value)
|
||||||
# Set the value with a TTL if it's provided.
|
# Set the value with a TTL if it's provided.
|
||||||
if ttl is not None:
|
if ttl is not None:
|
||||||
pipe.setex(cache_key, ttl, json.dumps(cache_value))
|
pipe.setex(cache_key, ttl, json_cache_value)
|
||||||
else:
|
else:
|
||||||
pipe.set(cache_key, json.dumps(cache_value))
|
pipe.set(cache_key, json_cache_value)
|
||||||
# Execute the pipeline and return the results.
|
# Execute the pipeline and return the results.
|
||||||
results = await pipe.execute()
|
results = await pipe.execute()
|
||||||
|
|
||||||
|
@ -416,12 +424,12 @@ class RedisCache(BaseCache):
|
||||||
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
|
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
|
||||||
await self.flush_cache_buffer() # logging done in here
|
await self.flush_cache_buffer() # logging done in here
|
||||||
|
|
||||||
async def async_increment(self, key, value: int, **kwargs) -> int:
|
async def async_increment(self, key, value: float, **kwargs) -> float:
|
||||||
_redis_client = self.init_async_client()
|
_redis_client = self.init_async_client()
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
try:
|
try:
|
||||||
async with _redis_client as redis_client:
|
async with _redis_client as redis_client:
|
||||||
result = await redis_client.incr(name=key, amount=value)
|
result = await redis_client.incrbyfloat(name=key, amount=value)
|
||||||
## LOGGING ##
|
## LOGGING ##
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
_duration = end_time - start_time
|
_duration = end_time - start_time
|
||||||
|
@ -803,9 +811,7 @@ class RedisSemanticCache(BaseCache):
|
||||||
|
|
||||||
# get the prompt
|
# get the prompt
|
||||||
messages = kwargs["messages"]
|
messages = kwargs["messages"]
|
||||||
prompt = ""
|
prompt = "".join(message["content"] for message in messages)
|
||||||
for message in messages:
|
|
||||||
prompt += message["content"]
|
|
||||||
|
|
||||||
# create an embedding for prompt
|
# create an embedding for prompt
|
||||||
embedding_response = litellm.embedding(
|
embedding_response = litellm.embedding(
|
||||||
|
@ -840,9 +846,7 @@ class RedisSemanticCache(BaseCache):
|
||||||
|
|
||||||
# get the messages
|
# get the messages
|
||||||
messages = kwargs["messages"]
|
messages = kwargs["messages"]
|
||||||
prompt = ""
|
prompt = "".join(message["content"] for message in messages)
|
||||||
for message in messages:
|
|
||||||
prompt += message["content"]
|
|
||||||
|
|
||||||
# convert to embedding
|
# convert to embedding
|
||||||
embedding_response = litellm.embedding(
|
embedding_response = litellm.embedding(
|
||||||
|
@ -902,9 +906,7 @@ class RedisSemanticCache(BaseCache):
|
||||||
|
|
||||||
# get the prompt
|
# get the prompt
|
||||||
messages = kwargs["messages"]
|
messages = kwargs["messages"]
|
||||||
prompt = ""
|
prompt = "".join(message["content"] for message in messages)
|
||||||
for message in messages:
|
|
||||||
prompt += message["content"]
|
|
||||||
# create an embedding for prompt
|
# create an embedding for prompt
|
||||||
router_model_names = (
|
router_model_names = (
|
||||||
[m["model_name"] for m in llm_model_list]
|
[m["model_name"] for m in llm_model_list]
|
||||||
|
@ -957,9 +959,7 @@ class RedisSemanticCache(BaseCache):
|
||||||
|
|
||||||
# get the messages
|
# get the messages
|
||||||
messages = kwargs["messages"]
|
messages = kwargs["messages"]
|
||||||
prompt = ""
|
prompt = "".join(message["content"] for message in messages)
|
||||||
for message in messages:
|
|
||||||
prompt += message["content"]
|
|
||||||
|
|
||||||
router_model_names = (
|
router_model_names = (
|
||||||
[m["model_name"] for m in llm_model_list]
|
[m["model_name"] for m in llm_model_list]
|
||||||
|
@ -1375,18 +1375,41 @@ class DualCache(BaseCache):
|
||||||
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
async def async_batch_set_cache(
|
||||||
|
self, cache_list: list, local_only: bool = False, **kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Batch write values to the cache
|
||||||
|
"""
|
||||||
|
print_verbose(
|
||||||
|
f"async batch set cache: cache keys: {cache_list}; local_only: {local_only}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
if self.in_memory_cache is not None:
|
||||||
|
await self.in_memory_cache.async_set_cache_pipeline(
|
||||||
|
cache_list=cache_list, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.redis_cache is not None and local_only == False:
|
||||||
|
await self.redis_cache.async_set_cache_pipeline(
|
||||||
|
cache_list=cache_list, ttl=kwargs.get("ttl", None)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
async def async_increment_cache(
|
async def async_increment_cache(
|
||||||
self, key, value: int, local_only: bool = False, **kwargs
|
self, key, value: float, local_only: bool = False, **kwargs
|
||||||
) -> int:
|
) -> float:
|
||||||
"""
|
"""
|
||||||
Key - the key in cache
|
Key - the key in cache
|
||||||
|
|
||||||
Value - int - the value you want to increment by
|
Value - float - the value you want to increment by
|
||||||
|
|
||||||
Returns - int - the incremented value
|
Returns - float - the incremented value
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
result: int = value
|
result: float = value
|
||||||
if self.in_memory_cache is not None:
|
if self.in_memory_cache is not None:
|
||||||
result = await self.in_memory_cache.async_increment(
|
result = await self.in_memory_cache.async_increment(
|
||||||
key, value, **kwargs
|
key, value, **kwargs
|
||||||
|
|
|
@ -9,25 +9,12 @@
|
||||||
|
|
||||||
## LiteLLM versions of the OpenAI Exception Types
|
## LiteLLM versions of the OpenAI Exception Types
|
||||||
|
|
||||||
from openai import (
|
import openai
|
||||||
AuthenticationError,
|
|
||||||
BadRequestError,
|
|
||||||
NotFoundError,
|
|
||||||
RateLimitError,
|
|
||||||
APIStatusError,
|
|
||||||
OpenAIError,
|
|
||||||
APIError,
|
|
||||||
APITimeoutError,
|
|
||||||
APIConnectionError,
|
|
||||||
APIResponseValidationError,
|
|
||||||
UnprocessableEntityError,
|
|
||||||
PermissionDeniedError,
|
|
||||||
)
|
|
||||||
import httpx
|
import httpx
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationError(AuthenticationError): # type: ignore
|
class AuthenticationError(openai.AuthenticationError): # type: ignore
|
||||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||||
self.status_code = 401
|
self.status_code = 401
|
||||||
self.message = message
|
self.message = message
|
||||||
|
@ -39,7 +26,7 @@ class AuthenticationError(AuthenticationError): # type: ignore
|
||||||
|
|
||||||
|
|
||||||
# raise when invalid models passed, example gpt-8
|
# raise when invalid models passed, example gpt-8
|
||||||
class NotFoundError(NotFoundError): # type: ignore
|
class NotFoundError(openai.NotFoundError): # type: ignore
|
||||||
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||||
self.status_code = 404
|
self.status_code = 404
|
||||||
self.message = message
|
self.message = message
|
||||||
|
@ -50,7 +37,7 @@ class NotFoundError(NotFoundError): # type: ignore
|
||||||
) # Call the base class constructor with the parameters it needs
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
class BadRequestError(BadRequestError): # type: ignore
|
class BadRequestError(openai.BadRequestError): # type: ignore
|
||||||
def __init__(
|
def __init__(
|
||||||
self, message, model, llm_provider, response: Optional[httpx.Response] = None
|
self, message, model, llm_provider, response: Optional[httpx.Response] = None
|
||||||
):
|
):
|
||||||
|
@ -69,7 +56,7 @@ class BadRequestError(BadRequestError): # type: ignore
|
||||||
) # Call the base class constructor with the parameters it needs
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
|
class UnprocessableEntityError(openai.UnprocessableEntityError): # type: ignore
|
||||||
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||||
self.status_code = 422
|
self.status_code = 422
|
||||||
self.message = message
|
self.message = message
|
||||||
|
@ -80,7 +67,7 @@ class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
|
||||||
) # Call the base class constructor with the parameters it needs
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
class Timeout(APITimeoutError): # type: ignore
|
class Timeout(openai.APITimeoutError): # type: ignore
|
||||||
def __init__(self, message, model, llm_provider):
|
def __init__(self, message, model, llm_provider):
|
||||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
@ -96,7 +83,7 @@ class Timeout(APITimeoutError): # type: ignore
|
||||||
return str(self.message)
|
return str(self.message)
|
||||||
|
|
||||||
|
|
||||||
class PermissionDeniedError(PermissionDeniedError): # type:ignore
|
class PermissionDeniedError(openai.PermissionDeniedError): # type:ignore
|
||||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||||
self.status_code = 403
|
self.status_code = 403
|
||||||
self.message = message
|
self.message = message
|
||||||
|
@ -107,7 +94,7 @@ class PermissionDeniedError(PermissionDeniedError): # type:ignore
|
||||||
) # Call the base class constructor with the parameters it needs
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
class RateLimitError(RateLimitError): # type: ignore
|
class RateLimitError(openai.RateLimitError): # type: ignore
|
||||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||||
self.status_code = 429
|
self.status_code = 429
|
||||||
self.message = message
|
self.message = message
|
||||||
|
@ -148,7 +135,7 @@ class ContentPolicyViolationError(BadRequestError): # type: ignore
|
||||||
) # Call the base class constructor with the parameters it needs
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
class ServiceUnavailableError(APIStatusError): # type: ignore
|
class ServiceUnavailableError(openai.APIStatusError): # type: ignore
|
||||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||||
self.status_code = 503
|
self.status_code = 503
|
||||||
self.message = message
|
self.message = message
|
||||||
|
@ -160,7 +147,7 @@ class ServiceUnavailableError(APIStatusError): # type: ignore
|
||||||
|
|
||||||
|
|
||||||
# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
|
# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
|
||||||
class APIError(APIError): # type: ignore
|
class APIError(openai.APIError): # type: ignore
|
||||||
def __init__(
|
def __init__(
|
||||||
self, status_code, message, llm_provider, model, request: httpx.Request
|
self, status_code, message, llm_provider, model, request: httpx.Request
|
||||||
):
|
):
|
||||||
|
@ -172,7 +159,7 @@ class APIError(APIError): # type: ignore
|
||||||
|
|
||||||
|
|
||||||
# raised if an invalid request (not get, delete, put, post) is made
|
# raised if an invalid request (not get, delete, put, post) is made
|
||||||
class APIConnectionError(APIConnectionError): # type: ignore
|
class APIConnectionError(openai.APIConnectionError): # type: ignore
|
||||||
def __init__(self, message, llm_provider, model, request: httpx.Request):
|
def __init__(self, message, llm_provider, model, request: httpx.Request):
|
||||||
self.message = message
|
self.message = message
|
||||||
self.llm_provider = llm_provider
|
self.llm_provider = llm_provider
|
||||||
|
@ -182,7 +169,7 @@ class APIConnectionError(APIConnectionError): # type: ignore
|
||||||
|
|
||||||
|
|
||||||
# raised if an invalid request (not get, delete, put, post) is made
|
# raised if an invalid request (not get, delete, put, post) is made
|
||||||
class APIResponseValidationError(APIResponseValidationError): # type: ignore
|
class APIResponseValidationError(openai.APIResponseValidationError): # type: ignore
|
||||||
def __init__(self, message, llm_provider, model):
|
def __init__(self, message, llm_provider, model):
|
||||||
self.message = message
|
self.message = message
|
||||||
self.llm_provider = llm_provider
|
self.llm_provider = llm_provider
|
||||||
|
@ -192,7 +179,7 @@ class APIResponseValidationError(APIResponseValidationError): # type: ignore
|
||||||
super().__init__(response=response, body=None, message=message)
|
super().__init__(response=response, body=None, message=message)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIError(OpenAIError): # type: ignore
|
class OpenAIError(openai.OpenAIError): # type: ignore
|
||||||
def __init__(self, original_exception):
|
def __init__(self, original_exception):
|
||||||
self.status_code = original_exception.http_status
|
self.status_code = original_exception.http_status
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
@ -214,7 +201,7 @@ class BudgetExceededError(Exception):
|
||||||
|
|
||||||
|
|
||||||
## DEPRECATED ##
|
## DEPRECATED ##
|
||||||
class InvalidRequestError(BadRequestError): # type: ignore
|
class InvalidRequestError(openai.BadRequestError): # type: ignore
|
||||||
def __init__(self, message, model, llm_provider):
|
def __init__(self, message, model, llm_provider):
|
||||||
self.status_code = 400
|
self.status_code = 400
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success + failure, log events to aispend.io
|
# On success + failure, log events to aispend.io
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import traceback
|
import traceback
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
|
|
@ -4,18 +4,30 @@ import datetime
|
||||||
class AthinaLogger:
|
class AthinaLogger:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
import os
|
import os
|
||||||
|
|
||||||
self.athina_api_key = os.getenv("ATHINA_API_KEY")
|
self.athina_api_key = os.getenv("ATHINA_API_KEY")
|
||||||
self.headers = {
|
self.headers = {
|
||||||
"athina-api-key": self.athina_api_key,
|
"athina-api-key": self.athina_api_key,
|
||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
self.athina_logging_url = "https://log.athina.ai/api/v1/log/inference"
|
self.athina_logging_url = "https://log.athina.ai/api/v1/log/inference"
|
||||||
self.additional_keys = ["environment", "prompt_slug", "customer_id", "customer_user_id", "session_id", "external_reference_id", "context", "expected_response", "user_query"]
|
self.additional_keys = [
|
||||||
|
"environment",
|
||||||
|
"prompt_slug",
|
||||||
|
"customer_id",
|
||||||
|
"customer_user_id",
|
||||||
|
"session_id",
|
||||||
|
"external_reference_id",
|
||||||
|
"context",
|
||||||
|
"expected_response",
|
||||||
|
"user_query",
|
||||||
|
]
|
||||||
|
|
||||||
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response_json = response_obj.model_dump() if response_obj else {}
|
response_json = response_obj.model_dump() if response_obj else {}
|
||||||
data = {
|
data = {
|
||||||
|
@ -23,19 +35,30 @@ class AthinaLogger:
|
||||||
"request": kwargs,
|
"request": kwargs,
|
||||||
"response": response_json,
|
"response": response_json,
|
||||||
"prompt_tokens": response_json.get("usage", {}).get("prompt_tokens"),
|
"prompt_tokens": response_json.get("usage", {}).get("prompt_tokens"),
|
||||||
"completion_tokens": response_json.get("usage", {}).get("completion_tokens"),
|
"completion_tokens": response_json.get("usage", {}).get(
|
||||||
|
"completion_tokens"
|
||||||
|
),
|
||||||
"total_tokens": response_json.get("usage", {}).get("total_tokens"),
|
"total_tokens": response_json.get("usage", {}).get("total_tokens"),
|
||||||
}
|
}
|
||||||
|
|
||||||
if type(end_time) == datetime.datetime and type(start_time) == datetime.datetime:
|
if (
|
||||||
data["response_time"] = int((end_time - start_time).total_seconds() * 1000)
|
type(end_time) == datetime.datetime
|
||||||
|
and type(start_time) == datetime.datetime
|
||||||
|
):
|
||||||
|
data["response_time"] = int(
|
||||||
|
(end_time - start_time).total_seconds() * 1000
|
||||||
|
)
|
||||||
|
|
||||||
if "messages" in kwargs:
|
if "messages" in kwargs:
|
||||||
data["prompt"] = kwargs.get("messages", None)
|
data["prompt"] = kwargs.get("messages", None)
|
||||||
|
|
||||||
# Directly add tools or functions if present
|
# Directly add tools or functions if present
|
||||||
optional_params = kwargs.get("optional_params", {})
|
optional_params = kwargs.get("optional_params", {})
|
||||||
data.update((k, v) for k, v in optional_params.items() if k in ["tools", "functions"])
|
data.update(
|
||||||
|
(k, v)
|
||||||
|
for k, v in optional_params.items()
|
||||||
|
if k in ["tools", "functions"]
|
||||||
|
)
|
||||||
|
|
||||||
# Add additional metadata keys
|
# Add additional metadata keys
|
||||||
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
||||||
|
@ -44,11 +67,19 @@ class AthinaLogger:
|
||||||
if key in metadata:
|
if key in metadata:
|
||||||
data[key] = metadata[key]
|
data[key] = metadata[key]
|
||||||
|
|
||||||
response = requests.post(self.athina_logging_url, headers=self.headers, data=json.dumps(data, default=str))
|
response = requests.post(
|
||||||
|
self.athina_logging_url,
|
||||||
|
headers=self.headers,
|
||||||
|
data=json.dumps(data, default=str),
|
||||||
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
print_verbose(f"Athina Logger Error - {response.text}, {response.status_code}")
|
print_verbose(
|
||||||
|
f"Athina Logger Error - {response.text}, {response.status_code}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print_verbose(f"Athina Logger Succeeded - {response.text}")
|
print_verbose(f"Athina Logger Succeeded - {response.text}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_verbose(f"Athina Logger Error - {e}, Stack trace: {traceback.format_exc()}")
|
print_verbose(
|
||||||
|
f"Athina Logger Error - {e}, Stack trace: {traceback.format_exc()}"
|
||||||
|
)
|
||||||
pass
|
pass
|
|
@ -1,9 +1,8 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success + failure, log events to aispend.io
|
# On success + failure, log events to aispend.io
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
import requests # type: ignore
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import traceback
|
import traceback
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
|
|
@ -3,14 +3,11 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success, logs events to Promptlayer
|
# On success, logs events to Promptlayer
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
|
||||||
|
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
|
|
||||||
from typing import Literal, Union
|
from typing import Literal, Union
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,8 +16,6 @@ import traceback
|
||||||
|
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import traceback
|
import traceback
|
||||||
import datetime, subprocess, sys
|
import datetime, subprocess, sys
|
||||||
import litellm, uuid
|
import litellm, uuid
|
||||||
|
|
|
@ -1,14 +1,11 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success, logs events to Promptlayer
|
# On success, logs events to Promptlayer
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
|
||||||
|
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
|
|
||||||
from typing import Literal, Union, Optional
|
from typing import Literal, Union, Optional
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,9 +2,7 @@
|
||||||
# On success + failure, log events to Supabase
|
# On success + failure, log events to Supabase
|
||||||
|
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
import requests # type: ignore
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import traceback
|
import traceback
|
||||||
import datetime, subprocess, sys
|
import datetime, subprocess, sys
|
||||||
import litellm, uuid
|
import litellm, uuid
|
||||||
|
|
|
@ -2,9 +2,7 @@
|
||||||
# On success + failure, log events to Supabase
|
# On success + failure, log events to Supabase
|
||||||
|
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
import requests # type: ignore
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import traceback
|
import traceback
|
||||||
import datetime, subprocess, sys
|
import datetime, subprocess, sys
|
||||||
import litellm, uuid
|
import litellm, uuid
|
||||||
|
|
|
@ -1,15 +1,17 @@
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
|
||||||
class GreenscaleLogger:
|
class GreenscaleLogger:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
import os
|
import os
|
||||||
|
|
||||||
self.greenscale_api_key = os.getenv("GREENSCALE_API_KEY")
|
self.greenscale_api_key = os.getenv("GREENSCALE_API_KEY")
|
||||||
self.headers = {
|
self.headers = {
|
||||||
"api-key": self.greenscale_api_key,
|
"api-key": self.greenscale_api_key,
|
||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
self.greenscale_logging_url = os.getenv("GREENSCALE_ENDPOINT")
|
self.greenscale_logging_url = os.getenv("GREENSCALE_ENDPOINT")
|
||||||
|
|
||||||
|
@ -19,13 +21,18 @@ class GreenscaleLogger:
|
||||||
data = {
|
data = {
|
||||||
"modelId": kwargs.get("model"),
|
"modelId": kwargs.get("model"),
|
||||||
"inputTokenCount": response_json.get("usage", {}).get("prompt_tokens"),
|
"inputTokenCount": response_json.get("usage", {}).get("prompt_tokens"),
|
||||||
"outputTokenCount": response_json.get("usage", {}).get("completion_tokens"),
|
"outputTokenCount": response_json.get("usage", {}).get(
|
||||||
|
"completion_tokens"
|
||||||
|
),
|
||||||
}
|
}
|
||||||
data["timestamp"] = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
|
data["timestamp"] = datetime.now(timezone.utc).strftime(
|
||||||
|
"%Y-%m-%dT%H:%M:%SZ"
|
||||||
|
)
|
||||||
|
|
||||||
if type(end_time) == datetime and type(start_time) == datetime:
|
if type(end_time) == datetime and type(start_time) == datetime:
|
||||||
data["invocationLatency"] = int((end_time - start_time).total_seconds() * 1000)
|
data["invocationLatency"] = int(
|
||||||
|
(end_time - start_time).total_seconds() * 1000
|
||||||
|
)
|
||||||
|
|
||||||
# Add additional metadata keys to tags
|
# Add additional metadata keys to tags
|
||||||
tags = []
|
tags = []
|
||||||
|
@ -37,15 +44,25 @@ class GreenscaleLogger:
|
||||||
elif key == "greenscale_application":
|
elif key == "greenscale_application":
|
||||||
data["application"] = value
|
data["application"] = value
|
||||||
else:
|
else:
|
||||||
tags.append({"key": key.replace("greenscale_", ""), "value": str(value)})
|
tags.append(
|
||||||
|
{"key": key.replace("greenscale_", ""), "value": str(value)}
|
||||||
|
)
|
||||||
|
|
||||||
data["tags"] = tags
|
data["tags"] = tags
|
||||||
|
|
||||||
response = requests.post(self.greenscale_logging_url, headers=self.headers, data=json.dumps(data, default=str))
|
response = requests.post(
|
||||||
|
self.greenscale_logging_url,
|
||||||
|
headers=self.headers,
|
||||||
|
data=json.dumps(data, default=str),
|
||||||
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
print_verbose(f"Greenscale Logger Error - {response.text}, {response.status_code}")
|
print_verbose(
|
||||||
|
f"Greenscale Logger Error - {response.text}, {response.status_code}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print_verbose(f"Greenscale Logger Succeeded - {response.text}")
|
print_verbose(f"Greenscale Logger Succeeded - {response.text}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_verbose(f"Greenscale Logger Error - {e}, Stack trace: {traceback.format_exc()}")
|
print_verbose(
|
||||||
|
f"Greenscale Logger Error - {e}, Stack trace: {traceback.format_exc()}"
|
||||||
|
)
|
||||||
pass
|
pass
|
|
@ -1,10 +1,8 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success, logs events to Helicone
|
# On success, logs events to Helicone
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success, logs events to Langfuse
|
# On success, logs events to Langfuse
|
||||||
import dotenv, os
|
import os
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import copy
|
import copy
|
||||||
import traceback
|
import traceback
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
@ -262,6 +260,23 @@ class LangFuseLogger:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tags = []
|
tags = []
|
||||||
|
try:
|
||||||
|
metadata = copy.deepcopy(
|
||||||
|
metadata
|
||||||
|
) # Avoid modifying the original metadata
|
||||||
|
except:
|
||||||
|
new_metadata = {}
|
||||||
|
for key, value in metadata.items():
|
||||||
|
if (
|
||||||
|
isinstance(value, list)
|
||||||
|
or isinstance(value, dict)
|
||||||
|
or isinstance(value, str)
|
||||||
|
or isinstance(value, int)
|
||||||
|
or isinstance(value, float)
|
||||||
|
):
|
||||||
|
new_metadata[key] = copy.deepcopy(value)
|
||||||
|
metadata = new_metadata
|
||||||
|
|
||||||
supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3")
|
supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3")
|
||||||
supports_prompt = Version(langfuse.version.__version__) >= Version("2.7.3")
|
supports_prompt = Version(langfuse.version.__version__) >= Version("2.7.3")
|
||||||
supports_costs = Version(langfuse.version.__version__) >= Version("2.7.3")
|
supports_costs = Version(langfuse.version.__version__) >= Version("2.7.3")
|
||||||
|
@ -272,36 +287,9 @@ class LangFuseLogger:
|
||||||
print_verbose(f"Langfuse Layer Logging - logging to langfuse v2 ")
|
print_verbose(f"Langfuse Layer Logging - logging to langfuse v2 ")
|
||||||
|
|
||||||
if supports_tags:
|
if supports_tags:
|
||||||
metadata_tags = metadata.get("tags", [])
|
metadata_tags = metadata.pop("tags", [])
|
||||||
tags = metadata_tags
|
tags = metadata_tags
|
||||||
|
|
||||||
trace_name = metadata.get("trace_name", None)
|
|
||||||
trace_id = metadata.get("trace_id", None)
|
|
||||||
existing_trace_id = metadata.get("existing_trace_id", None)
|
|
||||||
if trace_name is None and existing_trace_id is None:
|
|
||||||
# just log `litellm-{call_type}` as the trace name
|
|
||||||
## DO NOT SET TRACE_NAME if trace-id set. this can lead to overwriting of past traces.
|
|
||||||
trace_name = f"litellm-{kwargs.get('call_type', 'completion')}"
|
|
||||||
|
|
||||||
if existing_trace_id is not None:
|
|
||||||
trace_params = {"id": existing_trace_id}
|
|
||||||
else: # don't overwrite an existing trace
|
|
||||||
trace_params = {
|
|
||||||
"name": trace_name,
|
|
||||||
"input": input,
|
|
||||||
"user_id": metadata.get("trace_user_id", user_id),
|
|
||||||
"id": trace_id,
|
|
||||||
"session_id": metadata.get("session_id", None),
|
|
||||||
}
|
|
||||||
|
|
||||||
if level == "ERROR":
|
|
||||||
trace_params["status_message"] = output
|
|
||||||
else:
|
|
||||||
trace_params["output"] = output
|
|
||||||
|
|
||||||
cost = kwargs.get("response_cost", None)
|
|
||||||
print_verbose(f"trace: {cost}")
|
|
||||||
|
|
||||||
# Clean Metadata before logging - never log raw metadata
|
# Clean Metadata before logging - never log raw metadata
|
||||||
# the raw metadata can contain circular references which leads to infinite recursion
|
# the raw metadata can contain circular references which leads to infinite recursion
|
||||||
# we clean out all extra litellm metadata params before logging
|
# we clean out all extra litellm metadata params before logging
|
||||||
|
@ -328,6 +316,77 @@ class LangFuseLogger:
|
||||||
else:
|
else:
|
||||||
clean_metadata[key] = value
|
clean_metadata[key] = value
|
||||||
|
|
||||||
|
session_id = clean_metadata.pop("session_id", None)
|
||||||
|
trace_name = clean_metadata.pop("trace_name", None)
|
||||||
|
trace_id = clean_metadata.pop("trace_id", None)
|
||||||
|
existing_trace_id = clean_metadata.pop("existing_trace_id", None)
|
||||||
|
update_trace_keys = clean_metadata.pop("update_trace_keys", [])
|
||||||
|
debug = clean_metadata.pop("debug_langfuse", None)
|
||||||
|
mask_input = clean_metadata.pop("mask_input", False)
|
||||||
|
mask_output = clean_metadata.pop("mask_output", False)
|
||||||
|
|
||||||
|
if trace_name is None and existing_trace_id is None:
|
||||||
|
# just log `litellm-{call_type}` as the trace name
|
||||||
|
## DO NOT SET TRACE_NAME if trace-id set. this can lead to overwriting of past traces.
|
||||||
|
trace_name = f"litellm-{kwargs.get('call_type', 'completion')}"
|
||||||
|
|
||||||
|
if existing_trace_id is not None:
|
||||||
|
trace_params = {"id": existing_trace_id}
|
||||||
|
|
||||||
|
# Update the following keys for this trace
|
||||||
|
for metadata_param_key in update_trace_keys:
|
||||||
|
trace_param_key = metadata_param_key.replace("trace_", "")
|
||||||
|
if trace_param_key not in trace_params:
|
||||||
|
updated_trace_value = clean_metadata.pop(
|
||||||
|
metadata_param_key, None
|
||||||
|
)
|
||||||
|
if updated_trace_value is not None:
|
||||||
|
trace_params[trace_param_key] = updated_trace_value
|
||||||
|
|
||||||
|
# Pop the trace specific keys that would have been popped if there were a new trace
|
||||||
|
for key in list(
|
||||||
|
filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
|
||||||
|
):
|
||||||
|
clean_metadata.pop(key, None)
|
||||||
|
|
||||||
|
# Special keys that are found in the function arguments and not the metadata
|
||||||
|
if "input" in update_trace_keys:
|
||||||
|
trace_params["input"] = input if not mask_input else "redacted-by-litellm"
|
||||||
|
if "output" in update_trace_keys:
|
||||||
|
trace_params["output"] = output if not mask_output else "redacted-by-litellm"
|
||||||
|
else: # don't overwrite an existing trace
|
||||||
|
trace_params = {
|
||||||
|
"id": trace_id,
|
||||||
|
"name": trace_name,
|
||||||
|
"session_id": session_id,
|
||||||
|
"input": input if not mask_input else "redacted-by-litellm",
|
||||||
|
"version": clean_metadata.pop(
|
||||||
|
"trace_version", clean_metadata.get("version", None)
|
||||||
|
), # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence
|
||||||
|
"user_id": user_id,
|
||||||
|
}
|
||||||
|
for key in list(
|
||||||
|
filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
|
||||||
|
):
|
||||||
|
trace_params[key.replace("trace_", "")] = clean_metadata.pop(
|
||||||
|
key, None
|
||||||
|
)
|
||||||
|
|
||||||
|
if level == "ERROR":
|
||||||
|
trace_params["status_message"] = output
|
||||||
|
else:
|
||||||
|
trace_params["output"] = output if not mask_output else "redacted-by-litellm"
|
||||||
|
|
||||||
|
if debug == True or (isinstance(debug, str) and debug.lower() == "true"):
|
||||||
|
if "metadata" in trace_params:
|
||||||
|
# log the raw_metadata in the trace
|
||||||
|
trace_params["metadata"]["metadata_passed_to_litellm"] = metadata
|
||||||
|
else:
|
||||||
|
trace_params["metadata"] = {"metadata_passed_to_litellm": metadata}
|
||||||
|
|
||||||
|
cost = kwargs.get("response_cost", None)
|
||||||
|
print_verbose(f"trace: {cost}")
|
||||||
|
|
||||||
if (
|
if (
|
||||||
litellm._langfuse_default_tags is not None
|
litellm._langfuse_default_tags is not None
|
||||||
and isinstance(litellm._langfuse_default_tags, list)
|
and isinstance(litellm._langfuse_default_tags, list)
|
||||||
|
@ -375,7 +434,6 @@ class LangFuseLogger:
|
||||||
"url": url,
|
"url": url,
|
||||||
"headers": clean_headers,
|
"headers": clean_headers,
|
||||||
}
|
}
|
||||||
|
|
||||||
trace = self.Langfuse.trace(**trace_params)
|
trace = self.Langfuse.trace(**trace_params)
|
||||||
|
|
||||||
generation_id = None
|
generation_id = None
|
||||||
|
@ -387,7 +445,7 @@ class LangFuseLogger:
|
||||||
"completion_tokens": response_obj["usage"]["completion_tokens"],
|
"completion_tokens": response_obj["usage"]["completion_tokens"],
|
||||||
"total_cost": cost if supports_costs else None,
|
"total_cost": cost if supports_costs else None,
|
||||||
}
|
}
|
||||||
generation_name = metadata.get("generation_name", None)
|
generation_name = clean_metadata.pop("generation_name", None)
|
||||||
if generation_name is None:
|
if generation_name is None:
|
||||||
# just log `litellm-{call_type}` as the generation name
|
# just log `litellm-{call_type}` as the generation name
|
||||||
generation_name = f"litellm-{kwargs.get('call_type', 'completion')}"
|
generation_name = f"litellm-{kwargs.get('call_type', 'completion')}"
|
||||||
|
@ -402,20 +460,43 @@ class LangFuseLogger:
|
||||||
|
|
||||||
generation_params = {
|
generation_params = {
|
||||||
"name": generation_name,
|
"name": generation_name,
|
||||||
"id": metadata.get("generation_id", generation_id),
|
"id": clean_metadata.pop("generation_id", generation_id),
|
||||||
"start_time": start_time,
|
"start_time": start_time,
|
||||||
"end_time": end_time,
|
"end_time": end_time,
|
||||||
"model": kwargs["model"],
|
"model": kwargs["model"],
|
||||||
"model_parameters": optional_params,
|
"model_parameters": optional_params,
|
||||||
"input": input,
|
"input": input if not mask_input else "redacted-by-litellm",
|
||||||
"output": output,
|
"output": output if not mask_output else "redacted-by-litellm",
|
||||||
"usage": usage,
|
"usage": usage,
|
||||||
"metadata": clean_metadata,
|
"metadata": clean_metadata,
|
||||||
"level": level,
|
"level": level,
|
||||||
|
"version": clean_metadata.pop("version", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
if supports_prompt:
|
if supports_prompt:
|
||||||
generation_params["prompt"] = metadata.get("prompt", None)
|
user_prompt = clean_metadata.pop("prompt", None)
|
||||||
|
if user_prompt is None:
|
||||||
|
pass
|
||||||
|
elif isinstance(user_prompt, dict):
|
||||||
|
from langfuse.model import (
|
||||||
|
TextPromptClient,
|
||||||
|
ChatPromptClient,
|
||||||
|
Prompt_Text,
|
||||||
|
Prompt_Chat,
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_prompt.get("type", "") == "chat":
|
||||||
|
_prompt_chat = Prompt_Chat(**user_prompt)
|
||||||
|
generation_params["prompt"] = ChatPromptClient(
|
||||||
|
prompt=_prompt_chat
|
||||||
|
)
|
||||||
|
elif user_prompt.get("type", "") == "text":
|
||||||
|
_prompt_text = Prompt_Text(**user_prompt)
|
||||||
|
generation_params["prompt"] = TextPromptClient(
|
||||||
|
prompt=_prompt_text
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
generation_params["prompt"] = user_prompt
|
||||||
|
|
||||||
if output is not None and isinstance(output, str) and level == "ERROR":
|
if output is not None and isinstance(output, str) and level == "ERROR":
|
||||||
generation_params["status_message"] = output
|
generation_params["status_message"] = output
|
||||||
|
|
|
@ -1,15 +1,12 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success, logs events to Langsmith
|
# On success, logs events to Langsmith
|
||||||
import dotenv, os
|
import dotenv, os # type: ignore
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import requests
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import traceback
|
import traceback
|
||||||
import asyncio
|
import asyncio
|
||||||
import types
|
import types
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def is_serializable(value):
|
def is_serializable(value):
|
||||||
|
@ -79,8 +76,6 @@ class LangsmithLogger:
|
||||||
except:
|
except:
|
||||||
response_obj = response_obj.dict() # type: ignore
|
response_obj = response_obj.dict() # type: ignore
|
||||||
|
|
||||||
print(f"response_obj: {response_obj}")
|
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"name": run_name,
|
"name": run_name,
|
||||||
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
|
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
|
||||||
|
@ -90,7 +85,6 @@ class LangsmithLogger:
|
||||||
"start_time": start_time,
|
"start_time": start_time,
|
||||||
"end_time": end_time,
|
"end_time": end_time,
|
||||||
}
|
}
|
||||||
print(f"data: {data}")
|
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
"https://api.smith.langchain.com/runs",
|
"https://api.smith.langchain.com/runs",
|
||||||
|
|
|
@ -2,14 +2,10 @@
|
||||||
# On success + failure, log events to lunary.ai
|
# On success + failure, log events to lunary.ai
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
import traceback
|
import traceback
|
||||||
import dotenv
|
|
||||||
import importlib
|
import importlib
|
||||||
import sys
|
|
||||||
|
|
||||||
import packaging
|
import packaging
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
|
||||||
|
|
||||||
|
|
||||||
# convert to {completion: xx, tokens: xx}
|
# convert to {completion: xx, tokens: xx}
|
||||||
def parse_usage(usage):
|
def parse_usage(usage):
|
||||||
|
@ -18,13 +14,33 @@ def parse_usage(usage):
|
||||||
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
|
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def parse_tool_calls(tool_calls):
|
||||||
|
if tool_calls is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def clean_tool_call(tool_call):
|
||||||
|
|
||||||
|
serialized = {
|
||||||
|
"type": tool_call.type,
|
||||||
|
"id": tool_call.id,
|
||||||
|
"function": {
|
||||||
|
"name": tool_call.function.name,
|
||||||
|
"arguments": tool_call.function.arguments,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return serialized
|
||||||
|
|
||||||
|
return [clean_tool_call(tool_call) for tool_call in tool_calls]
|
||||||
|
|
||||||
|
|
||||||
def parse_messages(input):
|
def parse_messages(input):
|
||||||
|
|
||||||
if input is None:
|
if input is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def clean_message(message):
|
def clean_message(message):
|
||||||
# if is strin, return as is
|
# if is string, return as is
|
||||||
if isinstance(message, str):
|
if isinstance(message, str):
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
@ -38,9 +54,7 @@ def parse_messages(input):
|
||||||
|
|
||||||
# Only add tool_calls and function_call to res if they are set
|
# Only add tool_calls and function_call to res if they are set
|
||||||
if message.get("tool_calls"):
|
if message.get("tool_calls"):
|
||||||
serialized["tool_calls"] = message.get("tool_calls")
|
serialized["tool_calls"] = parse_tool_calls(message.get("tool_calls"))
|
||||||
if message.get("function_call"):
|
|
||||||
serialized["function_call"] = message.get("function_call")
|
|
||||||
|
|
||||||
return serialized
|
return serialized
|
||||||
|
|
||||||
|
@ -62,14 +76,16 @@ class LunaryLogger:
|
||||||
version = importlib.metadata.version("lunary")
|
version = importlib.metadata.version("lunary")
|
||||||
# if version < 0.1.43 then raise ImportError
|
# if version < 0.1.43 then raise ImportError
|
||||||
if packaging.version.Version(version) < packaging.version.Version("0.1.43"):
|
if packaging.version.Version(version) < packaging.version.Version("0.1.43"):
|
||||||
print(
|
print( # noqa
|
||||||
"Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
|
"Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
|
||||||
)
|
)
|
||||||
raise ImportError
|
raise ImportError
|
||||||
|
|
||||||
self.lunary_client = lunary
|
self.lunary_client = lunary
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("Lunary not installed. Please install it using 'pip install lunary'")
|
print( # noqa
|
||||||
|
"Lunary not installed. Please install it using 'pip install lunary'"
|
||||||
|
) # noqa
|
||||||
raise ImportError
|
raise ImportError
|
||||||
|
|
||||||
def log_event(
|
def log_event(
|
||||||
|
@ -93,8 +109,13 @@ class LunaryLogger:
|
||||||
print_verbose(f"Lunary Logging - Logging request for model {model}")
|
print_verbose(f"Lunary Logging - Logging request for model {model}")
|
||||||
|
|
||||||
litellm_params = kwargs.get("litellm_params", {})
|
litellm_params = kwargs.get("litellm_params", {})
|
||||||
|
optional_params = kwargs.get("optional_params", {})
|
||||||
metadata = litellm_params.get("metadata", {}) or {}
|
metadata = litellm_params.get("metadata", {}) or {}
|
||||||
|
|
||||||
|
if optional_params:
|
||||||
|
# merge into extra
|
||||||
|
extra = {**extra, **optional_params}
|
||||||
|
|
||||||
tags = litellm_params.pop("tags", None) or []
|
tags = litellm_params.pop("tags", None) or []
|
||||||
|
|
||||||
if extra:
|
if extra:
|
||||||
|
@ -104,7 +125,7 @@ class LunaryLogger:
|
||||||
|
|
||||||
# keep only serializable types
|
# keep only serializable types
|
||||||
for param, value in extra.items():
|
for param, value in extra.items():
|
||||||
if not isinstance(value, (str, int, bool, float)):
|
if not isinstance(value, (str, int, bool, float)) and param != "tools":
|
||||||
try:
|
try:
|
||||||
extra[param] = str(value)
|
extra[param] = str(value)
|
||||||
except:
|
except:
|
||||||
|
@ -140,7 +161,7 @@ class LunaryLogger:
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
runtime="litellm",
|
runtime="litellm",
|
||||||
tags=tags,
|
tags=tags,
|
||||||
extra=extra,
|
params=extra,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.lunary_client.track_event(
|
self.lunary_client.track_event(
|
||||||
|
|
|
@ -2,10 +2,7 @@
|
||||||
## On Success events log cost to OpenMeter - https://github.com/BerriAI/litellm/issues/1268
|
## On Success events log cost to OpenMeter - https://github.com/BerriAI/litellm/issues/1268
|
||||||
|
|
||||||
import dotenv, os, json
|
import dotenv, os, json
|
||||||
import requests
|
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import traceback
|
import traceback
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
|
@ -60,12 +57,16 @@ class OpenMeterLogger(CustomLogger):
|
||||||
"total_tokens": response_obj["usage"].get("total_tokens"),
|
"total_tokens": response_obj["usage"].get("total_tokens"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
subject = (kwargs.get("user", None),) # end-user passed in via 'user' param
|
||||||
|
if not subject:
|
||||||
|
raise Exception("OpenMeter: user is required")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"specversion": "1.0",
|
"specversion": "1.0",
|
||||||
"type": os.getenv("OPENMETER_EVENT_TYPE", "litellm_tokens"),
|
"type": os.getenv("OPENMETER_EVENT_TYPE", "litellm_tokens"),
|
||||||
"id": call_id,
|
"id": call_id,
|
||||||
"time": dt,
|
"time": dt,
|
||||||
"subject": kwargs.get("user", ""), # end-user passed in via 'user' param
|
"subject": subject,
|
||||||
"source": "litellm-proxy",
|
"source": "litellm-proxy",
|
||||||
"data": {"model": model, "cost": cost, **usage},
|
"data": {"model": model, "cost": cost, **usage},
|
||||||
}
|
}
|
||||||
|
@ -80,15 +81,24 @@ class OpenMeterLogger(CustomLogger):
|
||||||
api_key = os.getenv("OPENMETER_API_KEY")
|
api_key = os.getenv("OPENMETER_API_KEY")
|
||||||
|
|
||||||
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
|
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
|
||||||
self.sync_http_handler.post(
|
_headers = {
|
||||||
url=_url,
|
|
||||||
data=_data,
|
|
||||||
headers={
|
|
||||||
"Content-Type": "application/cloudevents+json",
|
"Content-Type": "application/cloudevents+json",
|
||||||
"Authorization": "Bearer {}".format(api_key),
|
"Authorization": "Bearer {}".format(api_key),
|
||||||
},
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.sync_http_handler.post(
|
||||||
|
url=_url,
|
||||||
|
data=json.dumps(_data),
|
||||||
|
headers=_headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
except Exception as e:
|
||||||
|
if hasattr(response, "text"):
|
||||||
|
litellm.print_verbose(f"\nError Message: {response.text}")
|
||||||
|
raise e
|
||||||
|
|
||||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
_url = os.getenv("OPENMETER_API_ENDPOINT", "https://openmeter.cloud")
|
_url = os.getenv("OPENMETER_API_ENDPOINT", "https://openmeter.cloud")
|
||||||
if _url.endswith("/"):
|
if _url.endswith("/"):
|
||||||
|
|
|
@ -3,9 +3,7 @@
|
||||||
# On success, log events to Prometheus
|
# On success, log events to Prometheus
|
||||||
|
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
import requests # type: ignore
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import traceback
|
import traceback
|
||||||
import datetime, subprocess, sys
|
import datetime, subprocess, sys
|
||||||
import litellm, uuid
|
import litellm, uuid
|
||||||
|
@ -19,7 +17,6 @@ class PrometheusLogger:
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
print(f"in init prometheus metrics")
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
self.litellm_llm_api_failed_requests_metric = Counter(
|
self.litellm_llm_api_failed_requests_metric = Counter(
|
||||||
|
|
|
@ -4,9 +4,7 @@
|
||||||
|
|
||||||
|
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
import requests # type: ignore
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import traceback
|
import traceback
|
||||||
import datetime, subprocess, sys
|
import datetime, subprocess, sys
|
||||||
import litellm, uuid
|
import litellm, uuid
|
||||||
|
@ -183,7 +181,6 @@ class PrometheusServicesLogger:
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
|
async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
|
||||||
print(f"received error payload: {payload.error}")
|
|
||||||
if self.mock_testing:
|
if self.mock_testing:
|
||||||
self.mock_testing_failure_calls += 1
|
self.mock_testing_failure_calls += 1
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,11 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success, logs events to Promptlayer
|
# On success, logs events to Promptlayer
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
import requests # type: ignore
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
class PromptLayerLogger:
|
class PromptLayerLogger:
|
||||||
# Class variables or attributes
|
# Class variables or attributes
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -32,7 +31,11 @@ class PromptLayerLogger:
|
||||||
tags = kwargs["litellm_params"]["metadata"]["pl_tags"]
|
tags = kwargs["litellm_params"]["metadata"]["pl_tags"]
|
||||||
|
|
||||||
# Remove "pl_tags" from metadata
|
# Remove "pl_tags" from metadata
|
||||||
metadata = {k:v for k, v in kwargs["litellm_params"]["metadata"].items() if k != "pl_tags"}
|
metadata = {
|
||||||
|
k: v
|
||||||
|
for k, v in kwargs["litellm_params"]["metadata"].items()
|
||||||
|
if k != "pl_tags"
|
||||||
|
}
|
||||||
|
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}"
|
f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}"
|
||||||
|
|
|
@ -1,10 +1,7 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success + failure, log events to Supabase
|
# On success + failure, log events to Supabase
|
||||||
|
|
||||||
import dotenv, os
|
import os
|
||||||
import requests
|
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import traceback
|
import traceback
|
||||||
import datetime, subprocess, sys
|
import datetime, subprocess, sys
|
||||||
import litellm, uuid
|
import litellm, uuid
|
||||||
|
|
|
@ -1,35 +1,90 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# Class for sending Slack Alerts #
|
# Class for sending Slack Alerts #
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import copy
|
|
||||||
import traceback
|
|
||||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||||
import litellm
|
import litellm, threading
|
||||||
from typing import List, Literal, Any, Union, Optional, Dict
|
from typing import List, Literal, Any, Union, Optional, Dict
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
import asyncio
|
import asyncio
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
import datetime
|
import datetime
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from enum import Enum
|
||||||
|
from datetime import datetime as dt, timedelta
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
class SlackAlerting:
|
class LiteLLMBase(BaseModel):
|
||||||
|
"""
|
||||||
|
Implements default functions, all pydantic objects should have.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def json(self, **kwargs):
|
||||||
|
try:
|
||||||
|
return self.model_dump() # noqa
|
||||||
|
except:
|
||||||
|
# if using pydantic v1
|
||||||
|
return self.dict()
|
||||||
|
|
||||||
|
|
||||||
|
class SlackAlertingArgs(LiteLLMBase):
|
||||||
|
default_daily_report_frequency: int = 12 * 60 * 60 # 12 hours
|
||||||
|
daily_report_frequency: int = int(os.getenv("SLACK_DAILY_REPORT_FREQUENCY", default_daily_report_frequency))
|
||||||
|
report_check_interval: int = 5 * 60 # 5 minutes
|
||||||
|
|
||||||
|
|
||||||
|
class DeploymentMetrics(LiteLLMBase):
|
||||||
|
"""
|
||||||
|
Metrics per deployment, stored in cache
|
||||||
|
|
||||||
|
Used for daily reporting
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
"""id of deployment in router model list"""
|
||||||
|
|
||||||
|
failed_request: bool
|
||||||
|
"""did it fail the request?"""
|
||||||
|
|
||||||
|
latency_per_output_token: Optional[float]
|
||||||
|
"""latency/output token of deployment"""
|
||||||
|
|
||||||
|
updated_at: dt
|
||||||
|
"""Current time of deployment being updated"""
|
||||||
|
|
||||||
|
|
||||||
|
class SlackAlertingCacheKeys(Enum):
|
||||||
|
"""
|
||||||
|
Enum for deployment daily metrics keys - {deployment_id}:{enum}
|
||||||
|
"""
|
||||||
|
|
||||||
|
failed_requests_key = "failed_requests_daily_metrics"
|
||||||
|
latency_key = "latency_daily_metrics"
|
||||||
|
report_sent_key = "daily_metrics_report_sent"
|
||||||
|
|
||||||
|
|
||||||
|
class SlackAlerting(CustomLogger):
|
||||||
|
"""
|
||||||
|
Class for sending Slack Alerts
|
||||||
|
"""
|
||||||
|
|
||||||
# Class variables or attributes
|
# Class variables or attributes
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
alerting_threshold: float = 300,
|
internal_usage_cache: Optional[DualCache] = None,
|
||||||
|
alerting_threshold: float = 300, # threshold for slow / hanging llm responses (in seconds)
|
||||||
alerting: Optional[List] = [],
|
alerting: Optional[List] = [],
|
||||||
alert_types: Optional[
|
alert_types: List[
|
||||||
List[
|
|
||||||
Literal[
|
Literal[
|
||||||
"llm_exceptions",
|
"llm_exceptions",
|
||||||
"llm_too_slow",
|
"llm_too_slow",
|
||||||
"llm_requests_hanging",
|
"llm_requests_hanging",
|
||||||
"budget_alerts",
|
"budget_alerts",
|
||||||
"db_exceptions",
|
"db_exceptions",
|
||||||
]
|
"daily_reports",
|
||||||
]
|
]
|
||||||
] = [
|
] = [
|
||||||
"llm_exceptions",
|
"llm_exceptions",
|
||||||
|
@ -37,18 +92,23 @@ class SlackAlerting:
|
||||||
"llm_requests_hanging",
|
"llm_requests_hanging",
|
||||||
"budget_alerts",
|
"budget_alerts",
|
||||||
"db_exceptions",
|
"db_exceptions",
|
||||||
|
"daily_reports",
|
||||||
],
|
],
|
||||||
alert_to_webhook_url: Optional[
|
alert_to_webhook_url: Optional[
|
||||||
Dict
|
Dict
|
||||||
] = None, # if user wants to separate alerts to diff channels
|
] = None, # if user wants to separate alerts to diff channels
|
||||||
|
alerting_args={},
|
||||||
|
default_webhook_url: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self.alerting_threshold = alerting_threshold
|
self.alerting_threshold = alerting_threshold
|
||||||
self.alerting = alerting
|
self.alerting = alerting
|
||||||
self.alert_types = alert_types
|
self.alert_types = alert_types
|
||||||
self.internal_usage_cache = DualCache()
|
self.internal_usage_cache = internal_usage_cache or DualCache()
|
||||||
self.async_http_handler = AsyncHTTPHandler()
|
self.async_http_handler = AsyncHTTPHandler()
|
||||||
self.alert_to_webhook_url = alert_to_webhook_url
|
self.alert_to_webhook_url = alert_to_webhook_url
|
||||||
pass
|
self.is_running = False
|
||||||
|
self.alerting_args = SlackAlertingArgs(**alerting_args)
|
||||||
|
self.default_webhook_url = default_webhook_url
|
||||||
|
|
||||||
def update_values(
|
def update_values(
|
||||||
self,
|
self,
|
||||||
|
@ -56,6 +116,7 @@ class SlackAlerting:
|
||||||
alerting_threshold: Optional[float] = None,
|
alerting_threshold: Optional[float] = None,
|
||||||
alert_types: Optional[List] = None,
|
alert_types: Optional[List] = None,
|
||||||
alert_to_webhook_url: Optional[Dict] = None,
|
alert_to_webhook_url: Optional[Dict] = None,
|
||||||
|
alerting_args: Optional[Dict] = None,
|
||||||
):
|
):
|
||||||
if alerting is not None:
|
if alerting is not None:
|
||||||
self.alerting = alerting
|
self.alerting = alerting
|
||||||
|
@ -63,7 +124,8 @@ class SlackAlerting:
|
||||||
self.alerting_threshold = alerting_threshold
|
self.alerting_threshold = alerting_threshold
|
||||||
if alert_types is not None:
|
if alert_types is not None:
|
||||||
self.alert_types = alert_types
|
self.alert_types = alert_types
|
||||||
|
if alerting_args is not None:
|
||||||
|
self.alerting_args = SlackAlertingArgs(**alerting_args)
|
||||||
if alert_to_webhook_url is not None:
|
if alert_to_webhook_url is not None:
|
||||||
# update the dict
|
# update the dict
|
||||||
if self.alert_to_webhook_url is None:
|
if self.alert_to_webhook_url is None:
|
||||||
|
@ -90,18 +152,23 @@ class SlackAlerting:
|
||||||
|
|
||||||
def _add_langfuse_trace_id_to_alert(
|
def _add_langfuse_trace_id_to_alert(
|
||||||
self,
|
self,
|
||||||
request_info: str,
|
|
||||||
request_data: Optional[dict] = None,
|
request_data: Optional[dict] = None,
|
||||||
kwargs: Optional[dict] = None,
|
) -> Optional[str]:
|
||||||
type: Literal["hanging_request", "slow_response"] = "hanging_request",
|
"""
|
||||||
start_time: Optional[datetime.datetime] = None,
|
Returns langfuse trace url
|
||||||
end_time: Optional[datetime.datetime] = None,
|
"""
|
||||||
):
|
|
||||||
# do nothing for now
|
# do nothing for now
|
||||||
pass
|
if (
|
||||||
return request_info
|
request_data is not None
|
||||||
|
and request_data.get("metadata", {}).get("trace_id", None) is not None
|
||||||
|
):
|
||||||
|
trace_id = request_data["metadata"]["trace_id"]
|
||||||
|
if litellm.utils.langFuseLogger is not None:
|
||||||
|
base_url = litellm.utils.langFuseLogger.Langfuse.base_url
|
||||||
|
return f"{base_url}/trace/{trace_id}"
|
||||||
|
return None
|
||||||
|
|
||||||
def _response_taking_too_long_callback(
|
def _response_taking_too_long_callback_helper(
|
||||||
self,
|
self,
|
||||||
kwargs, # kwargs to completion
|
kwargs, # kwargs to completion
|
||||||
start_time,
|
start_time,
|
||||||
|
@ -166,12 +233,14 @@ class SlackAlerting:
|
||||||
return
|
return
|
||||||
|
|
||||||
time_difference_float, model, api_base, messages = (
|
time_difference_float, model, api_base, messages = (
|
||||||
self._response_taking_too_long_callback(
|
self._response_taking_too_long_callback_helper(
|
||||||
kwargs=kwargs,
|
kwargs=kwargs,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if litellm.turn_off_message_logging:
|
||||||
|
messages = "Message not logged. `litellm.turn_off_message_logging=True`."
|
||||||
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
|
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
|
||||||
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
|
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
|
||||||
if time_difference_float > self.alerting_threshold:
|
if time_difference_float > self.alerting_threshold:
|
||||||
|
@ -182,6 +251,9 @@ class SlackAlerting:
|
||||||
and "metadata" in kwargs["litellm_params"]
|
and "metadata" in kwargs["litellm_params"]
|
||||||
):
|
):
|
||||||
_metadata = kwargs["litellm_params"]["metadata"]
|
_metadata = kwargs["litellm_params"]["metadata"]
|
||||||
|
request_info = litellm.utils._add_key_name_and_team_to_alert(
|
||||||
|
request_info=request_info, metadata=_metadata
|
||||||
|
)
|
||||||
|
|
||||||
_deployment_latency_map = self._get_deployment_latencies_to_alert(
|
_deployment_latency_map = self._get_deployment_latencies_to_alert(
|
||||||
metadata=_metadata
|
metadata=_metadata
|
||||||
|
@ -196,8 +268,178 @@ class SlackAlerting:
|
||||||
alert_type="llm_too_slow",
|
alert_type="llm_too_slow",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def log_failure_event(self, original_exception: Exception):
|
async def async_update_daily_reports(
|
||||||
pass
|
self, deployment_metrics: DeploymentMetrics
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Store the perf by deployment in cache
|
||||||
|
- Number of failed requests per deployment
|
||||||
|
- Latency / output tokens per deployment
|
||||||
|
|
||||||
|
'deployment_id:daily_metrics:failed_requests'
|
||||||
|
'deployment_id:daily_metrics:latency_per_output_token'
|
||||||
|
|
||||||
|
Returns
|
||||||
|
int - count of metrics set (1 - if just latency, 2 - if failed + latency)
|
||||||
|
"""
|
||||||
|
|
||||||
|
return_val = 0
|
||||||
|
try:
|
||||||
|
## FAILED REQUESTS ##
|
||||||
|
if deployment_metrics.failed_request:
|
||||||
|
await self.internal_usage_cache.async_increment_cache(
|
||||||
|
key="{}:{}".format(
|
||||||
|
deployment_metrics.id,
|
||||||
|
SlackAlertingCacheKeys.failed_requests_key.value,
|
||||||
|
),
|
||||||
|
value=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
return_val += 1
|
||||||
|
|
||||||
|
## LATENCY ##
|
||||||
|
if deployment_metrics.latency_per_output_token is not None:
|
||||||
|
await self.internal_usage_cache.async_increment_cache(
|
||||||
|
key="{}:{}".format(
|
||||||
|
deployment_metrics.id, SlackAlertingCacheKeys.latency_key.value
|
||||||
|
),
|
||||||
|
value=deployment_metrics.latency_per_output_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
return_val += 1
|
||||||
|
|
||||||
|
return return_val
|
||||||
|
except Exception as e:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def send_daily_reports(self, router) -> bool:
|
||||||
|
"""
|
||||||
|
Send a daily report on:
|
||||||
|
- Top 5 deployments with most failed requests
|
||||||
|
- Top 5 slowest deployments (normalized by latency/output tokens)
|
||||||
|
|
||||||
|
Get the value from redis cache (if available) or in-memory and send it
|
||||||
|
|
||||||
|
Cleanup:
|
||||||
|
- reset values in cache -> prevent memory leak
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True -> if successfuly sent
|
||||||
|
False -> if not sent
|
||||||
|
"""
|
||||||
|
|
||||||
|
ids = router.get_model_ids()
|
||||||
|
|
||||||
|
# get keys
|
||||||
|
failed_request_keys = [
|
||||||
|
"{}:{}".format(id, SlackAlertingCacheKeys.failed_requests_key.value)
|
||||||
|
for id in ids
|
||||||
|
]
|
||||||
|
latency_keys = [
|
||||||
|
"{}:{}".format(id, SlackAlertingCacheKeys.latency_key.value) for id in ids
|
||||||
|
]
|
||||||
|
|
||||||
|
combined_metrics_keys = failed_request_keys + latency_keys # reduce cache calls
|
||||||
|
|
||||||
|
combined_metrics_values = await self.internal_usage_cache.async_batch_get_cache(
|
||||||
|
keys=combined_metrics_keys
|
||||||
|
) # [1, 2, None, ..]
|
||||||
|
|
||||||
|
all_none = True
|
||||||
|
for val in combined_metrics_values:
|
||||||
|
if val is not None:
|
||||||
|
all_none = False
|
||||||
|
|
||||||
|
if all_none:
|
||||||
|
return False
|
||||||
|
|
||||||
|
failed_request_values = combined_metrics_values[
|
||||||
|
: len(failed_request_keys)
|
||||||
|
] # # [1, 2, None, ..]
|
||||||
|
latency_values = combined_metrics_values[len(failed_request_keys) :]
|
||||||
|
|
||||||
|
# find top 5 failed
|
||||||
|
## Replace None values with a placeholder value (-1 in this case)
|
||||||
|
placeholder_value = 0
|
||||||
|
replaced_failed_values = [
|
||||||
|
value if value is not None else placeholder_value
|
||||||
|
for value in failed_request_values
|
||||||
|
]
|
||||||
|
|
||||||
|
## Get the indices of top 5 keys with the highest numerical values (ignoring None values)
|
||||||
|
top_5_failed = sorted(
|
||||||
|
range(len(replaced_failed_values)),
|
||||||
|
key=lambda i: replaced_failed_values[i],
|
||||||
|
reverse=True,
|
||||||
|
)[:5]
|
||||||
|
|
||||||
|
# find top 5 slowest
|
||||||
|
# Replace None values with a placeholder value (-1 in this case)
|
||||||
|
placeholder_value = 0
|
||||||
|
replaced_slowest_values = [
|
||||||
|
value if value is not None else placeholder_value
|
||||||
|
for value in latency_values
|
||||||
|
]
|
||||||
|
|
||||||
|
# Get the indices of top 5 values with the highest numerical values (ignoring None values)
|
||||||
|
top_5_slowest = sorted(
|
||||||
|
range(len(replaced_slowest_values)),
|
||||||
|
key=lambda i: replaced_slowest_values[i],
|
||||||
|
reverse=True,
|
||||||
|
)[:5]
|
||||||
|
|
||||||
|
# format alert -> return the litellm model name + api base
|
||||||
|
message = f"\n\nHere are today's key metrics 📈: \n\n"
|
||||||
|
|
||||||
|
message += "\n\n*❗️ Top 5 Deployments with Most Failed Requests:*\n\n"
|
||||||
|
for i in range(len(top_5_failed)):
|
||||||
|
key = failed_request_keys[top_5_failed[i]].split(":")[0]
|
||||||
|
_deployment = router.get_model_info(key)
|
||||||
|
if isinstance(_deployment, dict):
|
||||||
|
deployment_name = _deployment["litellm_params"].get("model", "")
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
api_base = litellm.get_api_base(
|
||||||
|
model=deployment_name,
|
||||||
|
optional_params=(
|
||||||
|
_deployment["litellm_params"] if _deployment is not None else {}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if api_base is None:
|
||||||
|
api_base = ""
|
||||||
|
value = replaced_failed_values[top_5_failed[i]]
|
||||||
|
message += f"\t{i+1}. Deployment: `{deployment_name}`, Failed Requests: `{value}`, API Base: `{api_base}`\n"
|
||||||
|
|
||||||
|
message += "\n\n*😅 Top 5 Slowest Deployments:*\n\n"
|
||||||
|
for i in range(len(top_5_slowest)):
|
||||||
|
key = latency_keys[top_5_slowest[i]].split(":")[0]
|
||||||
|
_deployment = router.get_model_info(key)
|
||||||
|
if _deployment is not None:
|
||||||
|
deployment_name = _deployment["litellm_params"].get("model", "")
|
||||||
|
else:
|
||||||
|
deployment_name = ""
|
||||||
|
api_base = litellm.get_api_base(
|
||||||
|
model=deployment_name,
|
||||||
|
optional_params=(
|
||||||
|
_deployment["litellm_params"] if _deployment is not None else {}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
value = round(replaced_slowest_values[top_5_slowest[i]], 3)
|
||||||
|
message += f"\t{i+1}. Deployment: `{deployment_name}`, Latency per output token: `{value}s/token`, API Base: `{api_base}`\n\n"
|
||||||
|
|
||||||
|
# cache cleanup -> reset values to 0
|
||||||
|
latency_cache_keys = [(key, 0) for key in latency_keys]
|
||||||
|
failed_request_cache_keys = [(key, 0) for key in failed_request_keys]
|
||||||
|
combined_metrics_cache_keys = latency_cache_keys + failed_request_cache_keys
|
||||||
|
await self.internal_usage_cache.async_batch_set_cache(
|
||||||
|
cache_list=combined_metrics_cache_keys
|
||||||
|
)
|
||||||
|
|
||||||
|
# send alert
|
||||||
|
await self.send_alert(message=message, level="Low", alert_type="daily_reports")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
async def response_taking_too_long(
|
async def response_taking_too_long(
|
||||||
self,
|
self,
|
||||||
|
@ -221,6 +463,11 @@ class SlackAlerting:
|
||||||
messages = messages[:100]
|
messages = messages[:100]
|
||||||
except:
|
except:
|
||||||
messages = ""
|
messages = ""
|
||||||
|
|
||||||
|
if litellm.turn_off_message_logging:
|
||||||
|
messages = (
|
||||||
|
"Message not logged. `litellm.turn_off_message_logging=True`."
|
||||||
|
)
|
||||||
request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`"
|
request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`"
|
||||||
else:
|
else:
|
||||||
request_info = ""
|
request_info = ""
|
||||||
|
@ -255,6 +502,11 @@ class SlackAlerting:
|
||||||
# in that case we fallback to the api base set in the request metadata
|
# in that case we fallback to the api base set in the request metadata
|
||||||
_metadata = request_data["metadata"]
|
_metadata = request_data["metadata"]
|
||||||
_api_base = _metadata.get("api_base", "")
|
_api_base = _metadata.get("api_base", "")
|
||||||
|
|
||||||
|
request_info = litellm.utils._add_key_name_and_team_to_alert(
|
||||||
|
request_info=request_info, metadata=_metadata
|
||||||
|
)
|
||||||
|
|
||||||
if _api_base is None:
|
if _api_base is None:
|
||||||
_api_base = ""
|
_api_base = ""
|
||||||
request_info += f"\nAPI Base: `{_api_base}`"
|
request_info += f"\nAPI Base: `{_api_base}`"
|
||||||
|
@ -264,14 +516,13 @@ class SlackAlerting:
|
||||||
)
|
)
|
||||||
|
|
||||||
if "langfuse" in litellm.success_callback:
|
if "langfuse" in litellm.success_callback:
|
||||||
request_info = self._add_langfuse_trace_id_to_alert(
|
langfuse_url = self._add_langfuse_trace_id_to_alert(
|
||||||
request_info=request_info,
|
|
||||||
request_data=request_data,
|
request_data=request_data,
|
||||||
type="hanging_request",
|
|
||||||
start_time=start_time,
|
|
||||||
end_time=end_time,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if langfuse_url is not None:
|
||||||
|
request_info += "\n🪢 Langfuse Trace: {}".format(langfuse_url)
|
||||||
|
|
||||||
# add deployment latencies to alert
|
# add deployment latencies to alert
|
||||||
_deployment_latency_map = self._get_deployment_latencies_to_alert(
|
_deployment_latency_map = self._get_deployment_latencies_to_alert(
|
||||||
metadata=request_data.get("metadata", {})
|
metadata=request_data.get("metadata", {})
|
||||||
|
@ -404,6 +655,53 @@ class SlackAlerting:
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
async def model_added_alert(self, model_name: str, litellm_model_name: str):
|
||||||
|
model_info = litellm.model_cost.get(litellm_model_name, {})
|
||||||
|
model_info_str = ""
|
||||||
|
for k, v in model_info.items():
|
||||||
|
if k == "input_cost_per_token" or k == "output_cost_per_token":
|
||||||
|
# when converting to string it should not be 1.63e-06
|
||||||
|
v = "{:.8f}".format(v)
|
||||||
|
|
||||||
|
model_info_str += f"{k}: {v}\n"
|
||||||
|
|
||||||
|
message = f"""
|
||||||
|
*🚅 New Model Added*
|
||||||
|
Model Name: `{model_name}`
|
||||||
|
|
||||||
|
Usage OpenAI Python SDK:
|
||||||
|
```
|
||||||
|
import openai
|
||||||
|
client = openai.OpenAI(
|
||||||
|
api_key="your_api_key",
|
||||||
|
base_url={os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000")}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="{model_name}", # model to send to the proxy
|
||||||
|
messages = [
|
||||||
|
{{
|
||||||
|
"role": "user",
|
||||||
|
"content": "this is a test request, write a short poem"
|
||||||
|
}}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Model Info:
|
||||||
|
```
|
||||||
|
{model_info_str}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
await self.send_alert(
|
||||||
|
message=message, level="Low", alert_type="new_model_added"
|
||||||
|
)
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def model_removed_alert(self, model_name: str):
|
||||||
|
pass
|
||||||
|
|
||||||
async def send_alert(
|
async def send_alert(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
|
@ -414,7 +712,11 @@ class SlackAlerting:
|
||||||
"llm_requests_hanging",
|
"llm_requests_hanging",
|
||||||
"budget_alerts",
|
"budget_alerts",
|
||||||
"db_exceptions",
|
"db_exceptions",
|
||||||
|
"daily_reports",
|
||||||
|
"new_model_added",
|
||||||
|
"cooldown_deployment",
|
||||||
],
|
],
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
|
Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
|
||||||
|
@ -439,9 +741,16 @@ class SlackAlerting:
|
||||||
# Get the current timestamp
|
# Get the current timestamp
|
||||||
current_time = datetime.now().strftime("%H:%M:%S")
|
current_time = datetime.now().strftime("%H:%M:%S")
|
||||||
_proxy_base_url = os.getenv("PROXY_BASE_URL", None)
|
_proxy_base_url = os.getenv("PROXY_BASE_URL", None)
|
||||||
|
if alert_type == "daily_reports" or alert_type == "new_model_added":
|
||||||
|
formatted_message = message
|
||||||
|
else:
|
||||||
formatted_message = (
|
formatted_message = (
|
||||||
f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
|
f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
formatted_message += f"\n\n{key}: `{value}`\n\n"
|
||||||
if _proxy_base_url is not None:
|
if _proxy_base_url is not None:
|
||||||
formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
|
formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
|
||||||
|
|
||||||
|
@ -451,6 +760,8 @@ class SlackAlerting:
|
||||||
and alert_type in self.alert_to_webhook_url
|
and alert_type in self.alert_to_webhook_url
|
||||||
):
|
):
|
||||||
slack_webhook_url = self.alert_to_webhook_url[alert_type]
|
slack_webhook_url = self.alert_to_webhook_url[alert_type]
|
||||||
|
elif self.default_webhook_url is not None:
|
||||||
|
slack_webhook_url = self.default_webhook_url
|
||||||
else:
|
else:
|
||||||
slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
|
slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
|
||||||
|
|
||||||
|
@ -468,3 +779,201 @@ class SlackAlerting:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
print("Error sending slack alert. Error=", response.text) # noqa
|
print("Error sending slack alert. Error=", response.text) # noqa
|
||||||
|
|
||||||
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
"""Log deployment latency"""
|
||||||
|
if "daily_reports" in self.alert_types:
|
||||||
|
model_id = (
|
||||||
|
kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
|
||||||
|
)
|
||||||
|
response_s: timedelta = end_time - start_time
|
||||||
|
|
||||||
|
final_value = response_s
|
||||||
|
total_tokens = 0
|
||||||
|
|
||||||
|
if isinstance(response_obj, litellm.ModelResponse):
|
||||||
|
completion_tokens = response_obj.usage.completion_tokens
|
||||||
|
final_value = float(response_s.total_seconds() / completion_tokens)
|
||||||
|
|
||||||
|
await self.async_update_daily_reports(
|
||||||
|
DeploymentMetrics(
|
||||||
|
id=model_id,
|
||||||
|
failed_request=False,
|
||||||
|
latency_per_output_token=final_value,
|
||||||
|
updated_at=litellm.utils.get_utc_datetime(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
"""Log failure + deployment latency"""
|
||||||
|
if "daily_reports" in self.alert_types:
|
||||||
|
model_id = (
|
||||||
|
kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
|
||||||
|
)
|
||||||
|
await self.async_update_daily_reports(
|
||||||
|
DeploymentMetrics(
|
||||||
|
id=model_id,
|
||||||
|
failed_request=True,
|
||||||
|
latency_per_output_token=None,
|
||||||
|
updated_at=litellm.utils.get_utc_datetime(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _run_scheduler_helper(self, llm_router) -> bool:
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
- True -> report sent
|
||||||
|
- False -> report not sent
|
||||||
|
"""
|
||||||
|
report_sent_bool = False
|
||||||
|
|
||||||
|
report_sent = await self.internal_usage_cache.async_get_cache(
|
||||||
|
key=SlackAlertingCacheKeys.report_sent_key.value
|
||||||
|
) # None | datetime
|
||||||
|
|
||||||
|
current_time = litellm.utils.get_utc_datetime()
|
||||||
|
|
||||||
|
if report_sent is None:
|
||||||
|
_current_time = current_time.isoformat()
|
||||||
|
await self.internal_usage_cache.async_set_cache(
|
||||||
|
key=SlackAlertingCacheKeys.report_sent_key.value,
|
||||||
|
value=_current_time,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# check if current time - interval >= time last sent
|
||||||
|
delta = current_time - timedelta(
|
||||||
|
seconds=self.alerting_args.daily_report_frequency
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(report_sent, str):
|
||||||
|
report_sent = dt.fromisoformat(report_sent)
|
||||||
|
|
||||||
|
if delta >= report_sent:
|
||||||
|
# Sneak in the reporting logic here
|
||||||
|
await self.send_daily_reports(router=llm_router)
|
||||||
|
# Also, don't forget to update the report_sent time after sending the report!
|
||||||
|
_current_time = current_time.isoformat()
|
||||||
|
await self.internal_usage_cache.async_set_cache(
|
||||||
|
key=SlackAlertingCacheKeys.report_sent_key.value,
|
||||||
|
value=_current_time,
|
||||||
|
)
|
||||||
|
report_sent_bool = True
|
||||||
|
|
||||||
|
return report_sent_bool
|
||||||
|
|
||||||
|
async def _run_scheduled_daily_report(self, llm_router: Optional[Any] = None):
|
||||||
|
"""
|
||||||
|
If 'daily_reports' enabled
|
||||||
|
|
||||||
|
Ping redis cache every 5 minutes to check if we should send the report
|
||||||
|
|
||||||
|
If yes -> call send_daily_report()
|
||||||
|
"""
|
||||||
|
if llm_router is None or self.alert_types is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if "daily_reports" in self.alert_types:
|
||||||
|
while True:
|
||||||
|
await self._run_scheduler_helper(llm_router=llm_router)
|
||||||
|
interval = random.randint(
|
||||||
|
self.alerting_args.report_check_interval - 3,
|
||||||
|
self.alerting_args.report_check_interval + 3,
|
||||||
|
) # shuffle to prevent collisions
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
return
|
||||||
|
|
||||||
|
async def send_weekly_spend_report(self):
|
||||||
|
""" """
|
||||||
|
try:
|
||||||
|
from litellm.proxy.proxy_server import _get_spend_report_for_time_range
|
||||||
|
|
||||||
|
todays_date = datetime.datetime.now().date()
|
||||||
|
week_before = todays_date - datetime.timedelta(days=7)
|
||||||
|
|
||||||
|
weekly_spend_per_team, weekly_spend_per_tag = (
|
||||||
|
await _get_spend_report_for_time_range(
|
||||||
|
start_date=week_before.strftime("%Y-%m-%d"),
|
||||||
|
end_date=todays_date.strftime("%Y-%m-%d"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
_weekly_spend_message = f"*💸 Weekly Spend Report for `{week_before.strftime('%m-%d-%Y')} - {todays_date.strftime('%m-%d-%Y')}` *\n"
|
||||||
|
|
||||||
|
if weekly_spend_per_team is not None:
|
||||||
|
_weekly_spend_message += "\n*Team Spend Report:*\n"
|
||||||
|
for spend in weekly_spend_per_team:
|
||||||
|
_team_spend = spend["total_spend"]
|
||||||
|
_team_spend = float(_team_spend)
|
||||||
|
# round to 4 decimal places
|
||||||
|
_team_spend = round(_team_spend, 4)
|
||||||
|
_weekly_spend_message += (
|
||||||
|
f"Team: `{spend['team_alias']}` | Spend: `${_team_spend}`\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
if weekly_spend_per_tag is not None:
|
||||||
|
_weekly_spend_message += "\n*Tag Spend Report:*\n"
|
||||||
|
for spend in weekly_spend_per_tag:
|
||||||
|
_tag_spend = spend["total_spend"]
|
||||||
|
_tag_spend = float(_tag_spend)
|
||||||
|
# round to 4 decimal places
|
||||||
|
_tag_spend = round(_tag_spend, 4)
|
||||||
|
_weekly_spend_message += f"Tag: `{spend['individual_request_tag']}` | Spend: `${_tag_spend}`\n"
|
||||||
|
|
||||||
|
await self.send_alert(
|
||||||
|
message=_weekly_spend_message,
|
||||||
|
level="Low",
|
||||||
|
alert_type="daily_reports",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.error("Error sending weekly spend report", e)
|
||||||
|
|
||||||
|
async def send_monthly_spend_report(self):
|
||||||
|
""" """
|
||||||
|
try:
|
||||||
|
from calendar import monthrange
|
||||||
|
|
||||||
|
from litellm.proxy.proxy_server import _get_spend_report_for_time_range
|
||||||
|
|
||||||
|
todays_date = datetime.datetime.now().date()
|
||||||
|
first_day_of_month = todays_date.replace(day=1)
|
||||||
|
_, last_day_of_month = monthrange(todays_date.year, todays_date.month)
|
||||||
|
last_day_of_month = first_day_of_month + datetime.timedelta(
|
||||||
|
days=last_day_of_month - 1
|
||||||
|
)
|
||||||
|
|
||||||
|
monthly_spend_per_team, monthly_spend_per_tag = (
|
||||||
|
await _get_spend_report_for_time_range(
|
||||||
|
start_date=first_day_of_month.strftime("%Y-%m-%d"),
|
||||||
|
end_date=last_day_of_month.strftime("%Y-%m-%d"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
_spend_message = f"*💸 Monthly Spend Report for `{first_day_of_month.strftime('%m-%d-%Y')} - {last_day_of_month.strftime('%m-%d-%Y')}` *\n"
|
||||||
|
|
||||||
|
if monthly_spend_per_team is not None:
|
||||||
|
_spend_message += "\n*Team Spend Report:*\n"
|
||||||
|
for spend in monthly_spend_per_team:
|
||||||
|
_team_spend = spend["total_spend"]
|
||||||
|
_team_spend = float(_team_spend)
|
||||||
|
# round to 4 decimal places
|
||||||
|
_team_spend = round(_team_spend, 4)
|
||||||
|
_spend_message += (
|
||||||
|
f"Team: `{spend['team_alias']}` | Spend: `${_team_spend}`\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
if monthly_spend_per_tag is not None:
|
||||||
|
_spend_message += "\n*Tag Spend Report:*\n"
|
||||||
|
for spend in monthly_spend_per_tag:
|
||||||
|
_tag_spend = spend["total_spend"]
|
||||||
|
_tag_spend = float(_tag_spend)
|
||||||
|
# round to 4 decimal places
|
||||||
|
_tag_spend = round(_tag_spend, 4)
|
||||||
|
_spend_message += f"Tag: `{spend['individual_request_tag']}` | Spend: `${_tag_spend}`\n"
|
||||||
|
|
||||||
|
await self.send_alert(
|
||||||
|
message=_spend_message,
|
||||||
|
level="Low",
|
||||||
|
alert_type="daily_reports",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.error("Error sending weekly spend report", e)
|
||||||
|
|
|
@ -2,9 +2,7 @@
|
||||||
# On success + failure, log events to Supabase
|
# On success + failure, log events to Supabase
|
||||||
|
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
import requests # type: ignore
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import traceback
|
import traceback
|
||||||
import datetime, subprocess, sys
|
import datetime, subprocess, sys
|
||||||
import litellm
|
import litellm
|
||||||
|
|
|
@ -21,11 +21,11 @@ try:
|
||||||
# contains a (known) object attribute
|
# contains a (known) object attribute
|
||||||
object: Literal["chat.completion", "edit", "text_completion"]
|
object: Literal["chat.completion", "edit", "text_completion"]
|
||||||
|
|
||||||
def __getitem__(self, key: K) -> V:
|
def __getitem__(self, key: K) -> V: ... # noqa
|
||||||
... # pragma: no cover
|
|
||||||
|
|
||||||
def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
|
def get( # noqa
|
||||||
... # pragma: no cover
|
self, key: K, default: Optional[V] = None
|
||||||
|
) -> Optional[V]: ... # pragma: no cover
|
||||||
|
|
||||||
class OpenAIRequestResponseResolver:
|
class OpenAIRequestResponseResolver:
|
||||||
def __call__(
|
def __call__(
|
||||||
|
@ -173,12 +173,11 @@ except:
|
||||||
|
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success, logs events to Langfuse
|
# On success, logs events to Langfuse
|
||||||
import dotenv, os
|
import os
|
||||||
import requests
|
import requests
|
||||||
import requests
|
import requests
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
import os, types, traceback
|
import os, types, traceback
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time, httpx
|
import time, httpx # type: ignore
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
from litellm.utils import ModelResponse, Choices, Message
|
from litellm.utils import ModelResponse, Choices, Message
|
||||||
import litellm
|
import litellm
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
import os, types
|
import os, types
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||||
import httpx
|
import httpx # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class AlephAlphaError(Exception):
|
class AlephAlphaError(Exception):
|
||||||
|
|
|
@ -1,15 +1,15 @@
|
||||||
import os, types
|
import os, types
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests, copy
|
import requests, copy # type: ignore
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Optional, List
|
from typing import Callable, Optional, List, Union
|
||||||
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
|
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
|
||||||
import litellm
|
import litellm
|
||||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
from .base import BaseLLM
|
from .base import BaseLLM
|
||||||
import httpx
|
import httpx # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class AnthropicConstants(Enum):
|
class AnthropicConstants(Enum):
|
||||||
|
@ -84,6 +84,51 @@ class AnthropicConfig:
|
||||||
and v is not None
|
and v is not None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_supported_openai_params(self):
|
||||||
|
return [
|
||||||
|
"stream",
|
||||||
|
"stop",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"max_tokens",
|
||||||
|
"tools",
|
||||||
|
"tool_choice",
|
||||||
|
]
|
||||||
|
|
||||||
|
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||||
|
for param, value in non_default_params.items():
|
||||||
|
if param == "max_tokens":
|
||||||
|
optional_params["max_tokens"] = value
|
||||||
|
if param == "tools":
|
||||||
|
optional_params["tools"] = value
|
||||||
|
if param == "stream" and value == True:
|
||||||
|
optional_params["stream"] = value
|
||||||
|
if param == "stop":
|
||||||
|
if isinstance(value, str):
|
||||||
|
if (
|
||||||
|
value == "\n"
|
||||||
|
) and litellm.drop_params == True: # anthropic doesn't allow whitespace characters as stop-sequences
|
||||||
|
continue
|
||||||
|
value = [value]
|
||||||
|
elif isinstance(value, list):
|
||||||
|
new_v = []
|
||||||
|
for v in value:
|
||||||
|
if (
|
||||||
|
v == "\n"
|
||||||
|
) and litellm.drop_params == True: # anthropic doesn't allow whitespace characters as stop-sequences
|
||||||
|
continue
|
||||||
|
new_v.append(v)
|
||||||
|
if len(new_v) > 0:
|
||||||
|
value = new_v
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
optional_params["stop_sequences"] = value
|
||||||
|
if param == "temperature":
|
||||||
|
optional_params["temperature"] = value
|
||||||
|
if param == "top_p":
|
||||||
|
optional_params["top_p"] = value
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
|
||||||
# makes headers for API call
|
# makes headers for API call
|
||||||
def validate_environment(api_key, user_headers):
|
def validate_environment(api_key, user_headers):
|
||||||
|
@ -106,19 +151,23 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def process_response(
|
def process_streaming_response(
|
||||||
self,
|
self,
|
||||||
model,
|
model: str,
|
||||||
response,
|
response: Union[requests.Response, httpx.Response],
|
||||||
model_response,
|
model_response: ModelResponse,
|
||||||
_is_function_call,
|
stream: bool,
|
||||||
stream,
|
logging_obj: litellm.utils.Logging,
|
||||||
logging_obj,
|
optional_params: dict,
|
||||||
api_key,
|
api_key: str,
|
||||||
data,
|
data: Union[dict, str],
|
||||||
messages,
|
messages: List,
|
||||||
print_verbose,
|
print_verbose,
|
||||||
):
|
encoding,
|
||||||
|
) -> CustomStreamWrapper:
|
||||||
|
"""
|
||||||
|
Return stream object for tool-calling + streaming
|
||||||
|
"""
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.post_call(
|
logging_obj.post_call(
|
||||||
input=messages,
|
input=messages,
|
||||||
|
@ -134,17 +183,6 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
raise AnthropicError(
|
raise AnthropicError(
|
||||||
message=response.text, status_code=response.status_code
|
message=response.text, status_code=response.status_code
|
||||||
)
|
)
|
||||||
if "error" in completion_response:
|
|
||||||
raise AnthropicError(
|
|
||||||
message=str(completion_response["error"]),
|
|
||||||
status_code=response.status_code,
|
|
||||||
)
|
|
||||||
elif len(completion_response["content"]) == 0:
|
|
||||||
raise AnthropicError(
|
|
||||||
message="No content in response",
|
|
||||||
status_code=response.status_code,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
text_content = ""
|
text_content = ""
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
for content in completion_response["content"]:
|
for content in completion_response["content"]:
|
||||||
|
@ -162,7 +200,11 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
if "error" in completion_response:
|
||||||
|
raise AnthropicError(
|
||||||
|
message=str(completion_response["error"]),
|
||||||
|
status_code=response.status_code,
|
||||||
|
)
|
||||||
_message = litellm.Message(
|
_message = litellm.Message(
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
content=text_content or None,
|
content=text_content or None,
|
||||||
|
@ -176,12 +218,10 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
completion_response["stop_reason"]
|
completion_response["stop_reason"]
|
||||||
)
|
)
|
||||||
|
|
||||||
print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
|
|
||||||
if _is_function_call and stream:
|
|
||||||
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
|
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
|
||||||
# return an iterator
|
# return an iterator
|
||||||
streaming_model_response = ModelResponse(stream=True)
|
streaming_model_response = ModelResponse(stream=True)
|
||||||
streaming_model_response.choices[0].finish_reason = model_response.choices[
|
streaming_model_response.choices[0].finish_reason = model_response.choices[ # type: ignore
|
||||||
0
|
0
|
||||||
].finish_reason
|
].finish_reason
|
||||||
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
|
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
|
||||||
|
@ -220,6 +260,77 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
custom_llm_provider="cached_response",
|
custom_llm_provider="cached_response",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise AnthropicError(
|
||||||
|
status_code=422,
|
||||||
|
message="Unprocessable response object - {}".format(response.text),
|
||||||
|
)
|
||||||
|
|
||||||
|
def process_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
response: Union[requests.Response, httpx.Response],
|
||||||
|
model_response: ModelResponse,
|
||||||
|
stream: bool,
|
||||||
|
logging_obj: litellm.utils.Logging,
|
||||||
|
optional_params: dict,
|
||||||
|
api_key: str,
|
||||||
|
data: Union[dict, str],
|
||||||
|
messages: List,
|
||||||
|
print_verbose,
|
||||||
|
encoding,
|
||||||
|
) -> ModelResponse:
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=response.text,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
print_verbose(f"raw model_response: {response.text}")
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
try:
|
||||||
|
completion_response = response.json()
|
||||||
|
except:
|
||||||
|
raise AnthropicError(
|
||||||
|
message=response.text, status_code=response.status_code
|
||||||
|
)
|
||||||
|
if "error" in completion_response:
|
||||||
|
raise AnthropicError(
|
||||||
|
message=str(completion_response["error"]),
|
||||||
|
status_code=response.status_code,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
text_content = ""
|
||||||
|
tool_calls = []
|
||||||
|
for content in completion_response["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
text_content += content["text"]
|
||||||
|
## TOOL CALLING
|
||||||
|
elif content["type"] == "tool_use":
|
||||||
|
tool_calls.append(
|
||||||
|
{
|
||||||
|
"id": content["id"],
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": content["name"],
|
||||||
|
"arguments": json.dumps(content["input"]),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
_message = litellm.Message(
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
content=text_content or None,
|
||||||
|
)
|
||||||
|
model_response.choices[0].message = _message # type: ignore
|
||||||
|
model_response._hidden_params["original_response"] = completion_response[
|
||||||
|
"content"
|
||||||
|
] # allow user to access raw anthropic tool calling response
|
||||||
|
|
||||||
|
model_response.choices[0].finish_reason = map_finish_reason(
|
||||||
|
completion_response["stop_reason"]
|
||||||
|
)
|
||||||
|
|
||||||
## CALCULATING USAGE
|
## CALCULATING USAGE
|
||||||
prompt_tokens = completion_response["usage"]["input_tokens"]
|
prompt_tokens = completion_response["usage"]["input_tokens"]
|
||||||
|
@ -233,7 +344,7 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
completion_tokens=completion_tokens,
|
completion_tokens=completion_tokens,
|
||||||
total_tokens=total_tokens,
|
total_tokens=total_tokens,
|
||||||
)
|
)
|
||||||
model_response.usage = usage
|
setattr(model_response, "usage", usage) # type: ignore
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
async def acompletion_stream_function(
|
async def acompletion_stream_function(
|
||||||
|
@ -249,7 +360,7 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
logging_obj,
|
logging_obj,
|
||||||
stream,
|
stream,
|
||||||
_is_function_call,
|
_is_function_call,
|
||||||
data=None,
|
data: dict,
|
||||||
optional_params=None,
|
optional_params=None,
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
|
@ -291,29 +402,44 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
logging_obj,
|
logging_obj,
|
||||||
stream,
|
stream,
|
||||||
_is_function_call,
|
_is_function_call,
|
||||||
data=None,
|
data: dict,
|
||||||
optional_params=None,
|
optional_params: dict,
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
headers={},
|
headers={},
|
||||||
):
|
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||||
self.async_handler = AsyncHTTPHandler(
|
self.async_handler = AsyncHTTPHandler(
|
||||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||||
)
|
)
|
||||||
response = await self.async_handler.post(
|
response = await self.async_handler.post(
|
||||||
api_base, headers=headers, data=json.dumps(data)
|
api_base, headers=headers, data=json.dumps(data)
|
||||||
)
|
)
|
||||||
return self.process_response(
|
if stream and _is_function_call:
|
||||||
|
return self.process_streaming_response(
|
||||||
model=model,
|
model=model,
|
||||||
response=response,
|
response=response,
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
_is_function_call=_is_function_call,
|
|
||||||
stream=stream,
|
stream=stream,
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
data=data,
|
data=data,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
print_verbose=print_verbose,
|
print_verbose=print_verbose,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
return self.process_response(
|
||||||
|
model=model,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
stream=stream,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_key=api_key,
|
||||||
|
data=data,
|
||||||
|
messages=messages,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
)
|
)
|
||||||
|
|
||||||
def completion(
|
def completion(
|
||||||
|
@ -327,7 +453,7 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
encoding,
|
encoding,
|
||||||
api_key,
|
api_key,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
optional_params=None,
|
optional_params: dict,
|
||||||
acompletion=None,
|
acompletion=None,
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
|
@ -486,17 +612,33 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
raise AnthropicError(
|
raise AnthropicError(
|
||||||
status_code=response.status_code, message=response.text
|
status_code=response.status_code, message=response.text
|
||||||
)
|
)
|
||||||
return self.process_response(
|
|
||||||
|
if stream and _is_function_call:
|
||||||
|
return self.process_streaming_response(
|
||||||
model=model,
|
model=model,
|
||||||
response=response,
|
response=response,
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
_is_function_call=_is_function_call,
|
|
||||||
stream=stream,
|
stream=stream,
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
data=data,
|
data=data,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
print_verbose=print_verbose,
|
print_verbose=print_verbose,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
return self.process_response(
|
||||||
|
model=model,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
stream=stream,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_key=api_key,
|
||||||
|
data=data,
|
||||||
|
messages=messages,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
)
|
)
|
||||||
|
|
||||||
def embedding(self):
|
def embedding(self):
|
||||||
|
|
|
@ -100,7 +100,7 @@ class AnthropicTextCompletion(BaseLLM):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def process_response(
|
def _process_response(
|
||||||
self, model_response: ModelResponse, response, encoding, prompt: str, model: str
|
self, model_response: ModelResponse, response, encoding, prompt: str, model: str
|
||||||
):
|
):
|
||||||
## RESPONSE OBJECT
|
## RESPONSE OBJECT
|
||||||
|
@ -171,7 +171,7 @@ class AnthropicTextCompletion(BaseLLM):
|
||||||
additional_args={"complete_input_dict": data},
|
additional_args={"complete_input_dict": data},
|
||||||
)
|
)
|
||||||
|
|
||||||
response = self.process_response(
|
response = self._process_response(
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
response=response,
|
response=response,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
|
@ -330,7 +330,7 @@ class AnthropicTextCompletion(BaseLLM):
|
||||||
)
|
)
|
||||||
print_verbose(f"raw model_response: {response.text}")
|
print_verbose(f"raw model_response: {response.text}")
|
||||||
|
|
||||||
response = self.process_response(
|
response = self._process_response(
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
response=response,
|
response=response,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Optional, Union, Any
|
from typing import Optional, Union, Any, Literal
|
||||||
import types, requests
|
import types, requests
|
||||||
from .base import BaseLLM
|
from .base import BaseLLM
|
||||||
from litellm.utils import (
|
from litellm.utils import (
|
||||||
|
@ -8,14 +8,16 @@ from litellm.utils import (
|
||||||
CustomStreamWrapper,
|
CustomStreamWrapper,
|
||||||
convert_to_model_response_object,
|
convert_to_model_response_object,
|
||||||
TranscriptionResponse,
|
TranscriptionResponse,
|
||||||
|
get_secret,
|
||||||
)
|
)
|
||||||
from typing import Callable, Optional, BinaryIO
|
from typing import Callable, Optional, BinaryIO, List
|
||||||
from litellm import OpenAIConfig
|
from litellm import OpenAIConfig
|
||||||
import litellm, json
|
import litellm, json
|
||||||
import httpx
|
import httpx # type: ignore
|
||||||
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
|
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
|
||||||
from openai import AzureOpenAI, AsyncAzureOpenAI
|
from openai import AzureOpenAI, AsyncAzureOpenAI
|
||||||
import uuid
|
import uuid
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
class AzureOpenAIError(Exception):
|
class AzureOpenAIError(Exception):
|
||||||
|
@ -105,6 +107,12 @@ class AzureOpenAIConfig(OpenAIConfig):
|
||||||
optional_params["azure_ad_token"] = value
|
optional_params["azure_ad_token"] = value
|
||||||
return optional_params
|
return optional_params
|
||||||
|
|
||||||
|
def get_eu_regions(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
|
||||||
|
"""
|
||||||
|
return ["europe", "sweden", "switzerland", "france", "uk"]
|
||||||
|
|
||||||
|
|
||||||
def select_azure_base_url_or_endpoint(azure_client_params: dict):
|
def select_azure_base_url_or_endpoint(azure_client_params: dict):
|
||||||
# azure_client_params = {
|
# azure_client_params = {
|
||||||
|
@ -126,6 +134,51 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
|
||||||
return azure_client_params
|
return azure_client_params
|
||||||
|
|
||||||
|
|
||||||
|
def get_azure_ad_token_from_oidc(azure_ad_token: str):
|
||||||
|
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
|
||||||
|
azure_tenant = os.getenv("AZURE_TENANT_ID", None)
|
||||||
|
|
||||||
|
if azure_client_id is None or azure_tenant is None:
|
||||||
|
raise AzureOpenAIError(
|
||||||
|
status_code=422,
|
||||||
|
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
|
||||||
|
)
|
||||||
|
|
||||||
|
oidc_token = get_secret(azure_ad_token)
|
||||||
|
|
||||||
|
if oidc_token is None:
|
||||||
|
raise AzureOpenAIError(
|
||||||
|
status_code=401,
|
||||||
|
message="OIDC token could not be retrieved from secret manager.",
|
||||||
|
)
|
||||||
|
|
||||||
|
req_token = httpx.post(
|
||||||
|
f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token",
|
||||||
|
data={
|
||||||
|
"client_id": azure_client_id,
|
||||||
|
"grant_type": "client_credentials",
|
||||||
|
"scope": "https://cognitiveservices.azure.com/.default",
|
||||||
|
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||||
|
"client_assertion": oidc_token,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if req_token.status_code != 200:
|
||||||
|
raise AzureOpenAIError(
|
||||||
|
status_code=req_token.status_code,
|
||||||
|
message=req_token.text,
|
||||||
|
)
|
||||||
|
|
||||||
|
possible_azure_ad_token = req_token.json().get("access_token", None)
|
||||||
|
|
||||||
|
if possible_azure_ad_token is None:
|
||||||
|
raise AzureOpenAIError(
|
||||||
|
status_code=422, message="Azure AD Token not returned"
|
||||||
|
)
|
||||||
|
|
||||||
|
return possible_azure_ad_token
|
||||||
|
|
||||||
|
|
||||||
class AzureChatCompletion(BaseLLM):
|
class AzureChatCompletion(BaseLLM):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -137,6 +190,8 @@ class AzureChatCompletion(BaseLLM):
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
headers["api-key"] = api_key
|
headers["api-key"] = api_key
|
||||||
elif azure_ad_token is not None:
|
elif azure_ad_token is not None:
|
||||||
|
if azure_ad_token.startswith("oidc/"):
|
||||||
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
headers["Authorization"] = f"Bearer {azure_ad_token}"
|
headers["Authorization"] = f"Bearer {azure_ad_token}"
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
|
@ -151,7 +206,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
api_type: str,
|
api_type: str,
|
||||||
azure_ad_token: str,
|
azure_ad_token: str,
|
||||||
print_verbose: Callable,
|
print_verbose: Callable,
|
||||||
timeout,
|
timeout: Union[float, httpx.Timeout],
|
||||||
logging_obj,
|
logging_obj,
|
||||||
optional_params,
|
optional_params,
|
||||||
litellm_params,
|
litellm_params,
|
||||||
|
@ -189,6 +244,9 @@ class AzureChatCompletion(BaseLLM):
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
azure_client_params["api_key"] = api_key
|
azure_client_params["api_key"] = api_key
|
||||||
elif azure_ad_token is not None:
|
elif azure_ad_token is not None:
|
||||||
|
if azure_ad_token.startswith("oidc/"):
|
||||||
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
|
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
|
|
||||||
if acompletion is True:
|
if acompletion is True:
|
||||||
|
@ -276,6 +334,8 @@ class AzureChatCompletion(BaseLLM):
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
azure_client_params["api_key"] = api_key
|
azure_client_params["api_key"] = api_key
|
||||||
elif azure_ad_token is not None:
|
elif azure_ad_token is not None:
|
||||||
|
if azure_ad_token.startswith("oidc/"):
|
||||||
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
if client is None:
|
if client is None:
|
||||||
azure_client = AzureOpenAI(**azure_client_params)
|
azure_client = AzureOpenAI(**azure_client_params)
|
||||||
|
@ -351,6 +411,8 @@ class AzureChatCompletion(BaseLLM):
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
azure_client_params["api_key"] = api_key
|
azure_client_params["api_key"] = api_key
|
||||||
elif azure_ad_token is not None:
|
elif azure_ad_token is not None:
|
||||||
|
if azure_ad_token.startswith("oidc/"):
|
||||||
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
|
|
||||||
# setting Azure client
|
# setting Azure client
|
||||||
|
@ -422,6 +484,8 @@ class AzureChatCompletion(BaseLLM):
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
azure_client_params["api_key"] = api_key
|
azure_client_params["api_key"] = api_key
|
||||||
elif azure_ad_token is not None:
|
elif azure_ad_token is not None:
|
||||||
|
if azure_ad_token.startswith("oidc/"):
|
||||||
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
if client is None:
|
if client is None:
|
||||||
azure_client = AzureOpenAI(**azure_client_params)
|
azure_client = AzureOpenAI(**azure_client_params)
|
||||||
|
@ -478,6 +542,8 @@ class AzureChatCompletion(BaseLLM):
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
azure_client_params["api_key"] = api_key
|
azure_client_params["api_key"] = api_key
|
||||||
elif azure_ad_token is not None:
|
elif azure_ad_token is not None:
|
||||||
|
if azure_ad_token.startswith("oidc/"):
|
||||||
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
if client is None:
|
if client is None:
|
||||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
||||||
|
@ -599,6 +665,8 @@ class AzureChatCompletion(BaseLLM):
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
azure_client_params["api_key"] = api_key
|
azure_client_params["api_key"] = api_key
|
||||||
elif azure_ad_token is not None:
|
elif azure_ad_token is not None:
|
||||||
|
if azure_ad_token.startswith("oidc/"):
|
||||||
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
|
@ -755,6 +823,8 @@ class AzureChatCompletion(BaseLLM):
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
azure_client_params["api_key"] = api_key
|
azure_client_params["api_key"] = api_key
|
||||||
elif azure_ad_token is not None:
|
elif azure_ad_token is not None:
|
||||||
|
if azure_ad_token.startswith("oidc/"):
|
||||||
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
|
|
||||||
if aimg_generation == True:
|
if aimg_generation == True:
|
||||||
|
@ -833,6 +903,8 @@ class AzureChatCompletion(BaseLLM):
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
azure_client_params["api_key"] = api_key
|
azure_client_params["api_key"] = api_key
|
||||||
elif azure_ad_token is not None:
|
elif azure_ad_token is not None:
|
||||||
|
if azure_ad_token.startswith("oidc/"):
|
||||||
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
|
|
||||||
if max_retries is not None:
|
if max_retries is not None:
|
||||||
|
@ -952,6 +1024,81 @@ class AzureChatCompletion(BaseLLM):
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
def get_headers(
|
||||||
|
self,
|
||||||
|
model: Optional[str],
|
||||||
|
api_key: str,
|
||||||
|
api_base: str,
|
||||||
|
api_version: str,
|
||||||
|
timeout: float,
|
||||||
|
mode: str,
|
||||||
|
messages: Optional[list] = None,
|
||||||
|
input: Optional[list] = None,
|
||||||
|
prompt: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
client_session = litellm.client_session or httpx.Client(
|
||||||
|
transport=CustomHTTPTransport(), # handle dall-e-2 calls
|
||||||
|
)
|
||||||
|
if "gateway.ai.cloudflare.com" in api_base:
|
||||||
|
## build base url - assume api base includes resource name
|
||||||
|
if not api_base.endswith("/"):
|
||||||
|
api_base += "/"
|
||||||
|
api_base += f"{model}"
|
||||||
|
client = AzureOpenAI(
|
||||||
|
base_url=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
api_key=api_key,
|
||||||
|
timeout=timeout,
|
||||||
|
http_client=client_session,
|
||||||
|
)
|
||||||
|
model = None
|
||||||
|
# cloudflare ai gateway, needs model=None
|
||||||
|
else:
|
||||||
|
client = AzureOpenAI(
|
||||||
|
api_version=api_version,
|
||||||
|
azure_endpoint=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
timeout=timeout,
|
||||||
|
http_client=client_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
# only run this check if it's not cloudflare ai gateway
|
||||||
|
if model is None and mode != "image_generation":
|
||||||
|
raise Exception("model is not set")
|
||||||
|
|
||||||
|
completion = None
|
||||||
|
|
||||||
|
if messages is None:
|
||||||
|
messages = [{"role": "user", "content": "Hey"}]
|
||||||
|
try:
|
||||||
|
completion = client.chat.completions.with_raw_response.create(
|
||||||
|
model=model, # type: ignore
|
||||||
|
messages=messages, # type: ignore
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
response = {}
|
||||||
|
|
||||||
|
if completion is None or not hasattr(completion, "headers"):
|
||||||
|
raise Exception("invalid completion response")
|
||||||
|
|
||||||
|
if (
|
||||||
|
completion.headers.get("x-ratelimit-remaining-requests", None) is not None
|
||||||
|
): # not provided for dall-e requests
|
||||||
|
response["x-ratelimit-remaining-requests"] = completion.headers[
|
||||||
|
"x-ratelimit-remaining-requests"
|
||||||
|
]
|
||||||
|
|
||||||
|
if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
|
||||||
|
response["x-ratelimit-remaining-tokens"] = completion.headers[
|
||||||
|
"x-ratelimit-remaining-tokens"
|
||||||
|
]
|
||||||
|
|
||||||
|
if completion.headers.get("x-ms-region", None) is not None:
|
||||||
|
response["x-ms-region"] = completion.headers["x-ms-region"]
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
async def ahealth_check(
|
async def ahealth_check(
|
||||||
self,
|
self,
|
||||||
model: Optional[str],
|
model: Optional[str],
|
||||||
|
@ -963,7 +1110,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
messages: Optional[list] = None,
|
messages: Optional[list] = None,
|
||||||
input: Optional[list] = None,
|
input: Optional[list] = None,
|
||||||
prompt: Optional[str] = None,
|
prompt: Optional[str] = None,
|
||||||
):
|
) -> dict:
|
||||||
client_session = litellm.aclient_session or httpx.AsyncClient(
|
client_session = litellm.aclient_session or httpx.AsyncClient(
|
||||||
transport=AsyncCustomHTTPTransport(), # handle dall-e-2 calls
|
transport=AsyncCustomHTTPTransport(), # handle dall-e-2 calls
|
||||||
)
|
)
|
||||||
|
@ -1040,4 +1187,8 @@ class AzureChatCompletion(BaseLLM):
|
||||||
response["x-ratelimit-remaining-tokens"] = completion.headers[
|
response["x-ratelimit-remaining-tokens"] = completion.headers[
|
||||||
"x-ratelimit-remaining-tokens"
|
"x-ratelimit-remaining-tokens"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if completion.headers.get("x-ms-region", None) is not None:
|
||||||
|
response["x-ms-region"] = completion.headers["x-ms-region"]
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from typing import Optional, Union, Any
|
from typing import Optional, Union, Any
|
||||||
import types, requests
|
import types, requests # type: ignore
|
||||||
from .base import BaseLLM
|
from .base import BaseLLM
|
||||||
from litellm.utils import (
|
from litellm.utils import (
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
|
|
|
@ -1,12 +1,32 @@
|
||||||
## This is a template base class to be used for adding new LLM providers via API calls
|
## This is a template base class to be used for adding new LLM providers via API calls
|
||||||
import litellm
|
import litellm
|
||||||
import httpx
|
import httpx, requests
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
from litellm.utils import Logging
|
||||||
|
|
||||||
|
|
||||||
class BaseLLM:
|
class BaseLLM:
|
||||||
_client_session: Optional[httpx.Client] = None
|
_client_session: Optional[httpx.Client] = None
|
||||||
|
|
||||||
|
def process_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
response: Union[requests.Response, httpx.Response],
|
||||||
|
model_response: litellm.utils.ModelResponse,
|
||||||
|
stream: bool,
|
||||||
|
logging_obj: Logging,
|
||||||
|
optional_params: dict,
|
||||||
|
api_key: str,
|
||||||
|
data: Union[dict, str],
|
||||||
|
messages: list,
|
||||||
|
print_verbose,
|
||||||
|
encoding,
|
||||||
|
) -> litellm.utils.ModelResponse:
|
||||||
|
"""
|
||||||
|
Helper function to process the response across sync + async completion calls
|
||||||
|
"""
|
||||||
|
return model_response
|
||||||
|
|
||||||
def create_client_session(self):
|
def create_client_session(self):
|
||||||
if litellm.client_session:
|
if litellm.client_session:
|
||||||
_client_session = litellm.client_session
|
_client_session = litellm.client_session
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time
|
import time
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from litellm.utils import ModelResponse, Usage
|
from litellm.utils import ModelResponse, Usage
|
||||||
|
|
|
@ -4,7 +4,13 @@ from enum import Enum
|
||||||
import time, uuid
|
import time, uuid
|
||||||
from typing import Callable, Optional, Any, Union, List
|
from typing import Callable, Optional, Any, Union, List
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse
|
from litellm.utils import (
|
||||||
|
ModelResponse,
|
||||||
|
get_secret,
|
||||||
|
Usage,
|
||||||
|
ImageResponse,
|
||||||
|
map_finish_reason,
|
||||||
|
)
|
||||||
from .prompt_templates.factory import (
|
from .prompt_templates.factory import (
|
||||||
prompt_factory,
|
prompt_factory,
|
||||||
custom_prompt,
|
custom_prompt,
|
||||||
|
@ -46,6 +52,16 @@ class AmazonBedrockGlobalConfig:
|
||||||
optional_params[mapped_params[param]] = value
|
optional_params[mapped_params[param]] = value
|
||||||
return optional_params
|
return optional_params
|
||||||
|
|
||||||
|
def get_eu_regions(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Source: https://www.aws-services.info/bedrock.html
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
"eu-west-1",
|
||||||
|
"eu-west-3",
|
||||||
|
"eu-central-1",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class AmazonTitanConfig:
|
class AmazonTitanConfig:
|
||||||
"""
|
"""
|
||||||
|
@ -157,6 +173,7 @@ class AmazonAnthropicClaude3Config:
|
||||||
"stop",
|
"stop",
|
||||||
"temperature",
|
"temperature",
|
||||||
"top_p",
|
"top_p",
|
||||||
|
"extra_headers",
|
||||||
]
|
]
|
||||||
|
|
||||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||||
|
@ -524,6 +541,17 @@ class AmazonStabilityConfig:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def add_custom_header(headers):
|
||||||
|
"""Closure to capture the headers and add them."""
|
||||||
|
|
||||||
|
def callback(request, **kwargs):
|
||||||
|
"""Actual callback function that Boto3 will call."""
|
||||||
|
for header_name, header_value in headers.items():
|
||||||
|
request.headers.add_header(header_name, header_value)
|
||||||
|
|
||||||
|
return callback
|
||||||
|
|
||||||
|
|
||||||
def init_bedrock_client(
|
def init_bedrock_client(
|
||||||
region_name=None,
|
region_name=None,
|
||||||
aws_access_key_id: Optional[str] = None,
|
aws_access_key_id: Optional[str] = None,
|
||||||
|
@ -533,12 +561,13 @@ def init_bedrock_client(
|
||||||
aws_session_name: Optional[str] = None,
|
aws_session_name: Optional[str] = None,
|
||||||
aws_profile_name: Optional[str] = None,
|
aws_profile_name: Optional[str] = None,
|
||||||
aws_role_name: Optional[str] = None,
|
aws_role_name: Optional[str] = None,
|
||||||
timeout: Optional[int] = None,
|
aws_web_identity_token: Optional[str] = None,
|
||||||
|
extra_headers: Optional[dict] = None,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||||
):
|
):
|
||||||
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
||||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||||
|
|
||||||
## CHECK IS 'os.environ/' passed in
|
## CHECK IS 'os.environ/' passed in
|
||||||
# Define the list of parameters to check
|
# Define the list of parameters to check
|
||||||
params_to_check = [
|
params_to_check = [
|
||||||
|
@ -549,6 +578,7 @@ def init_bedrock_client(
|
||||||
aws_session_name,
|
aws_session_name,
|
||||||
aws_profile_name,
|
aws_profile_name,
|
||||||
aws_role_name,
|
aws_role_name,
|
||||||
|
aws_web_identity_token,
|
||||||
]
|
]
|
||||||
|
|
||||||
# Iterate over parameters and update if needed
|
# Iterate over parameters and update if needed
|
||||||
|
@ -564,6 +594,7 @@ def init_bedrock_client(
|
||||||
aws_session_name,
|
aws_session_name,
|
||||||
aws_profile_name,
|
aws_profile_name,
|
||||||
aws_role_name,
|
aws_role_name,
|
||||||
|
aws_web_identity_token,
|
||||||
) = params_to_check
|
) = params_to_check
|
||||||
|
|
||||||
### SET REGION NAME
|
### SET REGION NAME
|
||||||
|
@ -592,10 +623,48 @@ def init_bedrock_client(
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
|
if isinstance(timeout, float):
|
||||||
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
|
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
|
||||||
|
elif isinstance(timeout, httpx.Timeout):
|
||||||
|
config = boto3.session.Config(
|
||||||
|
connect_timeout=timeout.connect, read_timeout=timeout.read
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
config = boto3.session.Config()
|
||||||
|
|
||||||
### CHECK STS ###
|
### CHECK STS ###
|
||||||
if aws_role_name is not None and aws_session_name is not None:
|
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
|
||||||
|
oidc_token = get_secret(aws_web_identity_token)
|
||||||
|
|
||||||
|
if oidc_token is None:
|
||||||
|
raise BedrockError(
|
||||||
|
message="OIDC token could not be retrieved from secret manager.",
|
||||||
|
status_code=401,
|
||||||
|
)
|
||||||
|
|
||||||
|
sts_client = boto3.client(
|
||||||
|
"sts"
|
||||||
|
)
|
||||||
|
|
||||||
|
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||||
|
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||||
|
sts_response = sts_client.assume_role_with_web_identity(
|
||||||
|
RoleArn=aws_role_name,
|
||||||
|
RoleSessionName=aws_session_name,
|
||||||
|
WebIdentityToken=oidc_token,
|
||||||
|
DurationSeconds=3600,
|
||||||
|
)
|
||||||
|
|
||||||
|
client = boto3.client(
|
||||||
|
service_name="bedrock-runtime",
|
||||||
|
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
||||||
|
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
||||||
|
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
||||||
|
region_name=region_name,
|
||||||
|
endpoint_url=endpoint_url,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
elif aws_role_name is not None and aws_session_name is not None:
|
||||||
# use sts if role name passed in
|
# use sts if role name passed in
|
||||||
sts_client = boto3.client(
|
sts_client = boto3.client(
|
||||||
"sts",
|
"sts",
|
||||||
|
@ -647,6 +716,10 @@ def init_bedrock_client(
|
||||||
endpoint_url=endpoint_url,
|
endpoint_url=endpoint_url,
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
if extra_headers:
|
||||||
|
client.meta.events.register(
|
||||||
|
"before-sign.bedrock-runtime.*", add_custom_header(extra_headers)
|
||||||
|
)
|
||||||
|
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
@ -710,6 +783,7 @@ def completion(
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
|
extra_headers: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
exception_mapping_worked = False
|
exception_mapping_worked = False
|
||||||
_is_function_call = False
|
_is_function_call = False
|
||||||
|
@ -725,6 +799,7 @@ def completion(
|
||||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||||
"aws_bedrock_runtime_endpoint", None
|
"aws_bedrock_runtime_endpoint", None
|
||||||
)
|
)
|
||||||
|
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||||
|
|
||||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||||
client = optional_params.pop("aws_bedrock_client", None)
|
client = optional_params.pop("aws_bedrock_client", None)
|
||||||
|
@ -739,6 +814,8 @@ def completion(
|
||||||
aws_role_name=aws_role_name,
|
aws_role_name=aws_role_name,
|
||||||
aws_session_name=aws_session_name,
|
aws_session_name=aws_session_name,
|
||||||
aws_profile_name=aws_profile_name,
|
aws_profile_name=aws_profile_name,
|
||||||
|
aws_web_identity_token=aws_web_identity_token,
|
||||||
|
extra_headers=extra_headers,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1043,7 +1120,9 @@ def completion(
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_response["finish_reason"] = response_body["stop_reason"]
|
model_response["finish_reason"] = map_finish_reason(
|
||||||
|
response_body["stop_reason"]
|
||||||
|
)
|
||||||
_usage = litellm.Usage(
|
_usage = litellm.Usage(
|
||||||
prompt_tokens=response_body["usage"]["input_tokens"],
|
prompt_tokens=response_body["usage"]["input_tokens"],
|
||||||
completion_tokens=response_body["usage"]["output_tokens"],
|
completion_tokens=response_body["usage"]["output_tokens"],
|
||||||
|
@ -1194,7 +1273,7 @@ def _embedding_func_single(
|
||||||
"input_type", "search_document"
|
"input_type", "search_document"
|
||||||
) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
|
) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
|
||||||
data = {"texts": [input], **inference_params} # type: ignore
|
data = {"texts": [input], **inference_params} # type: ignore
|
||||||
body = json.dumps(data).encode("utf-8")
|
body = json.dumps(data).encode("utf-8") # type: ignore
|
||||||
## LOGGING
|
## LOGGING
|
||||||
request_str = f"""
|
request_str = f"""
|
||||||
response = client.invoke_model(
|
response = client.invoke_model(
|
||||||
|
@ -1258,6 +1337,7 @@ def embedding(
|
||||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||||
"aws_bedrock_runtime_endpoint", None
|
"aws_bedrock_runtime_endpoint", None
|
||||||
)
|
)
|
||||||
|
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||||
|
|
||||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||||
client = init_bedrock_client(
|
client = init_bedrock_client(
|
||||||
|
@ -1265,6 +1345,7 @@ def embedding(
|
||||||
aws_secret_access_key=aws_secret_access_key,
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
aws_region_name=aws_region_name,
|
aws_region_name=aws_region_name,
|
||||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||||
|
aws_web_identity_token=aws_web_identity_token,
|
||||||
aws_role_name=aws_role_name,
|
aws_role_name=aws_role_name,
|
||||||
aws_session_name=aws_session_name,
|
aws_session_name=aws_session_name,
|
||||||
)
|
)
|
||||||
|
@ -1347,6 +1428,7 @@ def image_generation(
|
||||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||||
"aws_bedrock_runtime_endpoint", None
|
"aws_bedrock_runtime_endpoint", None
|
||||||
)
|
)
|
||||||
|
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||||
|
|
||||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||||
client = init_bedrock_client(
|
client = init_bedrock_client(
|
||||||
|
@ -1354,6 +1436,7 @@ def image_generation(
|
||||||
aws_secret_access_key=aws_secret_access_key,
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
aws_region_name=aws_region_name,
|
aws_region_name=aws_region_name,
|
||||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||||
|
aws_web_identity_token=aws_web_identity_token,
|
||||||
aws_role_name=aws_role_name,
|
aws_role_name=aws_role_name,
|
||||||
aws_session_name=aws_session_name,
|
aws_session_name=aws_session_name,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
@ -1386,7 +1469,7 @@ def image_generation(
|
||||||
## LOGGING
|
## LOGGING
|
||||||
request_str = f"""
|
request_str = f"""
|
||||||
response = client.invoke_model(
|
response = client.invoke_model(
|
||||||
body={body},
|
body={body}, # type: ignore
|
||||||
modelId={modelId},
|
modelId={modelId},
|
||||||
accept="application/json",
|
accept="application/json",
|
||||||
contentType="application/json",
|
contentType="application/json",
|
||||||
|
|
733
litellm/llms/bedrock_httpx.py
Normal file
733
litellm/llms/bedrock_httpx.py
Normal file
|
@ -0,0 +1,733 @@
|
||||||
|
# What is this?
|
||||||
|
## Initial implementation of calling bedrock via httpx client (allows for async calls).
|
||||||
|
## V0 - just covers cohere command-r support
|
||||||
|
|
||||||
|
import os, types
|
||||||
|
import json
|
||||||
|
from enum import Enum
|
||||||
|
import requests, copy # type: ignore
|
||||||
|
import time
|
||||||
|
from typing import (
|
||||||
|
Callable,
|
||||||
|
Optional,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Union,
|
||||||
|
Any,
|
||||||
|
TypedDict,
|
||||||
|
Tuple,
|
||||||
|
Iterator,
|
||||||
|
AsyncIterator,
|
||||||
|
)
|
||||||
|
from litellm.utils import (
|
||||||
|
ModelResponse,
|
||||||
|
Usage,
|
||||||
|
map_finish_reason,
|
||||||
|
CustomStreamWrapper,
|
||||||
|
Message,
|
||||||
|
Choices,
|
||||||
|
get_secret,
|
||||||
|
Logging,
|
||||||
|
)
|
||||||
|
import litellm
|
||||||
|
from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
|
from .base import BaseLLM
|
||||||
|
import httpx # type: ignore
|
||||||
|
from .bedrock import BedrockError, convert_messages_to_prompt
|
||||||
|
from litellm.types.llms.bedrock import *
|
||||||
|
|
||||||
|
|
||||||
|
class AmazonCohereChatConfig:
|
||||||
|
"""
|
||||||
|
Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
|
||||||
|
"""
|
||||||
|
|
||||||
|
documents: Optional[List[Document]] = None
|
||||||
|
search_queries_only: Optional[bool] = None
|
||||||
|
preamble: Optional[str] = None
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
p: Optional[float] = None
|
||||||
|
k: Optional[float] = None
|
||||||
|
prompt_truncation: Optional[str] = None
|
||||||
|
frequency_penalty: Optional[float] = None
|
||||||
|
presence_penalty: Optional[float] = None
|
||||||
|
seed: Optional[int] = None
|
||||||
|
return_prompt: Optional[bool] = None
|
||||||
|
stop_sequences: Optional[List[str]] = None
|
||||||
|
raw_prompting: Optional[bool] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
documents: Optional[List[Document]] = None,
|
||||||
|
search_queries_only: Optional[bool] = None,
|
||||||
|
preamble: Optional[str] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
p: Optional[float] = None,
|
||||||
|
k: Optional[float] = None,
|
||||||
|
prompt_truncation: Optional[str] = None,
|
||||||
|
frequency_penalty: Optional[float] = None,
|
||||||
|
presence_penalty: Optional[float] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_prompt: Optional[bool] = None,
|
||||||
|
stop_sequences: Optional[str] = None,
|
||||||
|
raw_prompting: Optional[bool] = None,
|
||||||
|
) -> None:
|
||||||
|
locals_ = locals()
|
||||||
|
for key, value in locals_.items():
|
||||||
|
if key != "self" and value is not None:
|
||||||
|
setattr(self.__class__, key, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls):
|
||||||
|
return {
|
||||||
|
k: v
|
||||||
|
for k, v in cls.__dict__.items()
|
||||||
|
if not k.startswith("__")
|
||||||
|
and not isinstance(
|
||||||
|
v,
|
||||||
|
(
|
||||||
|
types.FunctionType,
|
||||||
|
types.BuiltinFunctionType,
|
||||||
|
classmethod,
|
||||||
|
staticmethod,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
and v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_supported_openai_params(self) -> List[str]:
|
||||||
|
return [
|
||||||
|
"max_tokens",
|
||||||
|
"stream",
|
||||||
|
"stop",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"frequency_penalty",
|
||||||
|
"presence_penalty",
|
||||||
|
"seed",
|
||||||
|
"stop",
|
||||||
|
]
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self, non_default_params: dict, optional_params: dict
|
||||||
|
) -> dict:
|
||||||
|
for param, value in non_default_params.items():
|
||||||
|
if param == "max_tokens":
|
||||||
|
optional_params["max_tokens"] = value
|
||||||
|
if param == "stream":
|
||||||
|
optional_params["stream"] = value
|
||||||
|
if param == "stop":
|
||||||
|
if isinstance(value, str):
|
||||||
|
value = [value]
|
||||||
|
optional_params["stop_sequences"] = value
|
||||||
|
if param == "temperature":
|
||||||
|
optional_params["temperature"] = value
|
||||||
|
if param == "top_p":
|
||||||
|
optional_params["p"] = value
|
||||||
|
if param == "frequency_penalty":
|
||||||
|
optional_params["frequency_penalty"] = value
|
||||||
|
if param == "presence_penalty":
|
||||||
|
optional_params["presence_penalty"] = value
|
||||||
|
if "seed":
|
||||||
|
optional_params["seed"] = value
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockLLM(BaseLLM):
|
||||||
|
"""
|
||||||
|
Example call
|
||||||
|
|
||||||
|
```
|
||||||
|
curl --location --request POST 'https://bedrock-runtime.{aws_region_name}.amazonaws.com/model/{bedrock_model_name}/invoke' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--header 'Accept: application/json' \
|
||||||
|
--user "$AWS_ACCESS_KEY_ID":"$AWS_SECRET_ACCESS_KEY" \
|
||||||
|
--aws-sigv4 "aws:amz:us-east-1:bedrock" \
|
||||||
|
--data-raw '{
|
||||||
|
"prompt": "Hi",
|
||||||
|
"temperature": 0,
|
||||||
|
"p": 0.9,
|
||||||
|
"max_tokens": 4096
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def convert_messages_to_prompt(
|
||||||
|
self, model, messages, provider, custom_prompt_dict
|
||||||
|
) -> Tuple[str, Optional[list]]:
|
||||||
|
# handle anthropic prompts and amazon titan prompts
|
||||||
|
prompt = ""
|
||||||
|
chat_history: Optional[list] = None
|
||||||
|
if provider == "anthropic" or provider == "amazon":
|
||||||
|
if model in custom_prompt_dict:
|
||||||
|
# check if the model has a registered custom prompt
|
||||||
|
model_prompt_details = custom_prompt_dict[model]
|
||||||
|
prompt = custom_prompt(
|
||||||
|
role_dict=model_prompt_details["roles"],
|
||||||
|
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||||
|
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = prompt_factory(
|
||||||
|
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||||
|
)
|
||||||
|
elif provider == "mistral":
|
||||||
|
prompt = prompt_factory(
|
||||||
|
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||||
|
)
|
||||||
|
elif provider == "meta":
|
||||||
|
prompt = prompt_factory(
|
||||||
|
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||||
|
)
|
||||||
|
elif provider == "cohere":
|
||||||
|
prompt, chat_history = cohere_message_pt(messages=messages)
|
||||||
|
else:
|
||||||
|
prompt = ""
|
||||||
|
for message in messages:
|
||||||
|
if "role" in message:
|
||||||
|
if message["role"] == "user":
|
||||||
|
prompt += f"{message['content']}"
|
||||||
|
else:
|
||||||
|
prompt += f"{message['content']}"
|
||||||
|
else:
|
||||||
|
prompt += f"{message['content']}"
|
||||||
|
return prompt, chat_history # type: ignore
|
||||||
|
|
||||||
|
def get_credentials(
|
||||||
|
self,
|
||||||
|
aws_access_key_id: Optional[str] = None,
|
||||||
|
aws_secret_access_key: Optional[str] = None,
|
||||||
|
aws_region_name: Optional[str] = None,
|
||||||
|
aws_session_name: Optional[str] = None,
|
||||||
|
aws_profile_name: Optional[str] = None,
|
||||||
|
aws_role_name: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Return a boto3.Credentials object
|
||||||
|
"""
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
## CHECK IS 'os.environ/' passed in
|
||||||
|
params_to_check: List[Optional[str]] = [
|
||||||
|
aws_access_key_id,
|
||||||
|
aws_secret_access_key,
|
||||||
|
aws_region_name,
|
||||||
|
aws_session_name,
|
||||||
|
aws_profile_name,
|
||||||
|
aws_role_name,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Iterate over parameters and update if needed
|
||||||
|
for i, param in enumerate(params_to_check):
|
||||||
|
if param and param.startswith("os.environ/"):
|
||||||
|
_v = get_secret(param)
|
||||||
|
if _v is not None and isinstance(_v, str):
|
||||||
|
params_to_check[i] = _v
|
||||||
|
# Assign updated values back to parameters
|
||||||
|
(
|
||||||
|
aws_access_key_id,
|
||||||
|
aws_secret_access_key,
|
||||||
|
aws_region_name,
|
||||||
|
aws_session_name,
|
||||||
|
aws_profile_name,
|
||||||
|
aws_role_name,
|
||||||
|
) = params_to_check
|
||||||
|
|
||||||
|
### CHECK STS ###
|
||||||
|
if aws_role_name is not None and aws_session_name is not None:
|
||||||
|
sts_client = boto3.client(
|
||||||
|
"sts",
|
||||||
|
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
|
||||||
|
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
|
||||||
|
)
|
||||||
|
|
||||||
|
sts_response = sts_client.assume_role(
|
||||||
|
RoleArn=aws_role_name, RoleSessionName=aws_session_name
|
||||||
|
)
|
||||||
|
|
||||||
|
return sts_response["Credentials"]
|
||||||
|
elif aws_profile_name is not None: ### CHECK SESSION ###
|
||||||
|
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
||||||
|
client = boto3.Session(profile_name=aws_profile_name)
|
||||||
|
|
||||||
|
return client.get_credentials()
|
||||||
|
else:
|
||||||
|
session = boto3.Session(
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
region_name=aws_region_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
return session.get_credentials()
|
||||||
|
|
||||||
|
def process_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
response: Union[requests.Response, httpx.Response],
|
||||||
|
model_response: ModelResponse,
|
||||||
|
stream: bool,
|
||||||
|
logging_obj: Logging,
|
||||||
|
optional_params: dict,
|
||||||
|
api_key: str,
|
||||||
|
data: Union[dict, str],
|
||||||
|
messages: List,
|
||||||
|
print_verbose,
|
||||||
|
encoding,
|
||||||
|
) -> ModelResponse:
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=response.text,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
print_verbose(f"raw model_response: {response.text}")
|
||||||
|
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
try:
|
||||||
|
completion_response = response.json()
|
||||||
|
except:
|
||||||
|
raise BedrockError(message=response.text, status_code=422)
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_response.choices[0].message.content = completion_response["text"] # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
raise BedrockError(message=response.text, status_code=422)
|
||||||
|
|
||||||
|
## CALCULATING USAGE - bedrock returns usage in the headers
|
||||||
|
prompt_tokens = int(
|
||||||
|
response.headers.get(
|
||||||
|
"x-amzn-bedrock-input-token-count",
|
||||||
|
len(encoding.encode("".join(m.get("content", "") for m in messages))),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
completion_tokens = int(
|
||||||
|
response.headers.get(
|
||||||
|
"x-amzn-bedrock-output-token-count",
|
||||||
|
len(
|
||||||
|
encoding.encode(
|
||||||
|
model_response.choices[0].message.content, # type: ignore
|
||||||
|
disallowed_special=(),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model_response["created"] = int(time.time())
|
||||||
|
model_response["model"] = model
|
||||||
|
usage = Usage(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=prompt_tokens + completion_tokens,
|
||||||
|
)
|
||||||
|
setattr(model_response, "usage", usage)
|
||||||
|
|
||||||
|
return model_response
|
||||||
|
|
||||||
|
def completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
custom_prompt_dict: dict,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
encoding,
|
||||||
|
logging_obj,
|
||||||
|
optional_params: dict,
|
||||||
|
acompletion: bool,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
extra_headers: Optional[dict] = None,
|
||||||
|
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||||
|
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from botocore.auth import SigV4Auth
|
||||||
|
from botocore.awsrequest import AWSRequest
|
||||||
|
from botocore.credentials import Credentials
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||||
|
|
||||||
|
## SETUP ##
|
||||||
|
stream = optional_params.pop("stream", None)
|
||||||
|
|
||||||
|
## CREDENTIALS ##
|
||||||
|
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||||
|
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||||
|
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||||
|
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||||
|
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||||
|
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||||
|
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||||
|
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||||
|
"aws_bedrock_runtime_endpoint", None
|
||||||
|
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||||
|
|
||||||
|
### SET REGION NAME ###
|
||||||
|
if aws_region_name is None:
|
||||||
|
# check env #
|
||||||
|
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||||
|
|
||||||
|
if litellm_aws_region_name is not None and isinstance(
|
||||||
|
litellm_aws_region_name, str
|
||||||
|
):
|
||||||
|
aws_region_name = litellm_aws_region_name
|
||||||
|
|
||||||
|
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||||
|
if standard_aws_region_name is not None and isinstance(
|
||||||
|
standard_aws_region_name, str
|
||||||
|
):
|
||||||
|
aws_region_name = standard_aws_region_name
|
||||||
|
|
||||||
|
if aws_region_name is None:
|
||||||
|
aws_region_name = "us-west-2"
|
||||||
|
|
||||||
|
credentials: Credentials = self.get_credentials(
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
aws_region_name=aws_region_name,
|
||||||
|
aws_session_name=aws_session_name,
|
||||||
|
aws_profile_name=aws_profile_name,
|
||||||
|
aws_role_name=aws_role_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
### SET RUNTIME ENDPOINT ###
|
||||||
|
endpoint_url = ""
|
||||||
|
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||||
|
if aws_bedrock_runtime_endpoint is not None and isinstance(
|
||||||
|
aws_bedrock_runtime_endpoint, str
|
||||||
|
):
|
||||||
|
endpoint_url = aws_bedrock_runtime_endpoint
|
||||||
|
elif env_aws_bedrock_runtime_endpoint and isinstance(
|
||||||
|
env_aws_bedrock_runtime_endpoint, str
|
||||||
|
):
|
||||||
|
endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||||
|
else:
|
||||||
|
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||||
|
|
||||||
|
if stream is not None and stream == True:
|
||||||
|
endpoint_url = f"{endpoint_url}/model/{model}/invoke-with-response-stream"
|
||||||
|
else:
|
||||||
|
endpoint_url = f"{endpoint_url}/model/{model}/invoke"
|
||||||
|
|
||||||
|
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||||
|
|
||||||
|
provider = model.split(".")[0]
|
||||||
|
prompt, chat_history = self.convert_messages_to_prompt(
|
||||||
|
model, messages, provider, custom_prompt_dict
|
||||||
|
)
|
||||||
|
inference_params = copy.deepcopy(optional_params)
|
||||||
|
|
||||||
|
if provider == "cohere":
|
||||||
|
if model.startswith("cohere.command-r"):
|
||||||
|
## LOAD CONFIG
|
||||||
|
config = litellm.AmazonCohereChatConfig().get_config()
|
||||||
|
for k, v in config.items():
|
||||||
|
if (
|
||||||
|
k not in inference_params
|
||||||
|
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||||
|
inference_params[k] = v
|
||||||
|
_data = {"message": prompt, **inference_params}
|
||||||
|
if chat_history is not None:
|
||||||
|
_data["chat_history"] = chat_history
|
||||||
|
data = json.dumps(_data)
|
||||||
|
else:
|
||||||
|
## LOAD CONFIG
|
||||||
|
config = litellm.AmazonCohereConfig.get_config()
|
||||||
|
for k, v in config.items():
|
||||||
|
if (
|
||||||
|
k not in inference_params
|
||||||
|
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||||
|
inference_params[k] = v
|
||||||
|
if stream == True:
|
||||||
|
inference_params["stream"] = (
|
||||||
|
True # cohere requires stream = True in inference params
|
||||||
|
)
|
||||||
|
data = json.dumps({"prompt": prompt, **inference_params})
|
||||||
|
else:
|
||||||
|
raise Exception("UNSUPPORTED PROVIDER")
|
||||||
|
|
||||||
|
## COMPLETION CALL
|
||||||
|
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
if extra_headers is not None:
|
||||||
|
headers = {"Content-Type": "application/json", **extra_headers}
|
||||||
|
request = AWSRequest(
|
||||||
|
method="POST", url=endpoint_url, data=data, headers=headers
|
||||||
|
)
|
||||||
|
sigv4.add_auth(request)
|
||||||
|
prepped = request.prepare()
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.pre_call(
|
||||||
|
input=messages,
|
||||||
|
api_key="",
|
||||||
|
additional_args={
|
||||||
|
"complete_input_dict": data,
|
||||||
|
"api_base": prepped.url,
|
||||||
|
"headers": prepped.headers,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||||
|
if acompletion:
|
||||||
|
if isinstance(client, HTTPHandler):
|
||||||
|
client = None
|
||||||
|
if stream:
|
||||||
|
return self.async_streaming(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
data=data,
|
||||||
|
api_base=prepped.url,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
stream=True,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=prepped.headers,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
) # type: ignore
|
||||||
|
### ASYNC COMPLETION
|
||||||
|
return self.async_completion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
data=data,
|
||||||
|
api_base=prepped.url,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
stream=False,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=prepped.headers,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||||
|
_params = {}
|
||||||
|
if timeout is not None:
|
||||||
|
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||||
|
timeout = httpx.Timeout(timeout)
|
||||||
|
_params["timeout"] = timeout
|
||||||
|
self.client = HTTPHandler(**_params) # type: ignore
|
||||||
|
else:
|
||||||
|
self.client = client
|
||||||
|
if stream is not None and stream == True:
|
||||||
|
response = self.client.post(
|
||||||
|
url=prepped.url,
|
||||||
|
headers=prepped.headers, # type: ignore
|
||||||
|
data=data,
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise BedrockError(
|
||||||
|
status_code=response.status_code, message=response.text
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder = AWSEventStreamDecoder()
|
||||||
|
|
||||||
|
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
|
||||||
|
streaming_response = CustomStreamWrapper(
|
||||||
|
completion_stream=completion_stream,
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="bedrock",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
return streaming_response
|
||||||
|
|
||||||
|
response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
|
||||||
|
|
||||||
|
try:
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as err:
|
||||||
|
error_code = err.response.status_code
|
||||||
|
raise BedrockError(status_code=error_code, message=response.text)
|
||||||
|
|
||||||
|
return self.process_response(
|
||||||
|
model=model,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
stream=stream,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
api_key="",
|
||||||
|
data=data,
|
||||||
|
messages=messages,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
api_base: str,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
data: str,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
encoding,
|
||||||
|
logging_obj,
|
||||||
|
stream,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
headers={},
|
||||||
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
) -> ModelResponse:
|
||||||
|
if client is None:
|
||||||
|
_params = {}
|
||||||
|
if timeout is not None:
|
||||||
|
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||||
|
timeout = httpx.Timeout(timeout)
|
||||||
|
_params["timeout"] = timeout
|
||||||
|
self.client = AsyncHTTPHandler(**_params) # type: ignore
|
||||||
|
else:
|
||||||
|
self.client = client # type: ignore
|
||||||
|
|
||||||
|
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
|
||||||
|
return self.process_response(
|
||||||
|
model=model,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
stream=stream,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_key="",
|
||||||
|
data=data,
|
||||||
|
messages=messages,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_streaming(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
api_base: str,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
data: str,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
encoding,
|
||||||
|
logging_obj,
|
||||||
|
stream,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
headers={},
|
||||||
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
) -> CustomStreamWrapper:
|
||||||
|
if client is None:
|
||||||
|
_params = {}
|
||||||
|
if timeout is not None:
|
||||||
|
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||||
|
timeout = httpx.Timeout(timeout)
|
||||||
|
_params["timeout"] = timeout
|
||||||
|
self.client = AsyncHTTPHandler(**_params) # type: ignore
|
||||||
|
else:
|
||||||
|
self.client = client # type: ignore
|
||||||
|
|
||||||
|
response = await self.client.post(api_base, headers=headers, data=data, stream=True) # type: ignore
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise BedrockError(status_code=response.status_code, message=response.text)
|
||||||
|
|
||||||
|
decoder = AWSEventStreamDecoder()
|
||||||
|
|
||||||
|
completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
|
||||||
|
streaming_response = CustomStreamWrapper(
|
||||||
|
completion_stream=completion_stream,
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="bedrock",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
return streaming_response
|
||||||
|
|
||||||
|
def embedding(self, *args, **kwargs):
|
||||||
|
return super().embedding(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_response_stream_shape():
|
||||||
|
from botocore.model import ServiceModel
|
||||||
|
from botocore.loaders import Loader
|
||||||
|
|
||||||
|
loader = Loader()
|
||||||
|
bedrock_service_dict = loader.load_service_model("bedrock-runtime", "service-2")
|
||||||
|
bedrock_service_model = ServiceModel(bedrock_service_dict)
|
||||||
|
return bedrock_service_model.shape_for("ResponseStream")
|
||||||
|
|
||||||
|
|
||||||
|
class AWSEventStreamDecoder:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
from botocore.parsers import EventStreamJSONParser
|
||||||
|
|
||||||
|
self.parser = EventStreamJSONParser()
|
||||||
|
|
||||||
|
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
|
||||||
|
"""Given an iterator that yields lines, iterate over it & yield every event encountered"""
|
||||||
|
from botocore.eventstream import EventStreamBuffer
|
||||||
|
|
||||||
|
event_stream_buffer = EventStreamBuffer()
|
||||||
|
for chunk in iterator:
|
||||||
|
event_stream_buffer.add_data(chunk)
|
||||||
|
for event in event_stream_buffer:
|
||||||
|
message = self._parse_message_from_event(event)
|
||||||
|
if message:
|
||||||
|
# sse_event = ServerSentEvent(data=message, event="completion")
|
||||||
|
_data = json.loads(message)
|
||||||
|
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
|
||||||
|
text=_data.get("text", ""),
|
||||||
|
is_finished=_data.get("is_finished", False),
|
||||||
|
finish_reason=_data.get("finish_reason", ""),
|
||||||
|
)
|
||||||
|
yield streaming_chunk
|
||||||
|
|
||||||
|
async def aiter_bytes(
|
||||||
|
self, iterator: AsyncIterator[bytes]
|
||||||
|
) -> AsyncIterator[GenericStreamingChunk]:
|
||||||
|
"""Given an async iterator that yields lines, iterate over it & yield every event encountered"""
|
||||||
|
from botocore.eventstream import EventStreamBuffer
|
||||||
|
|
||||||
|
event_stream_buffer = EventStreamBuffer()
|
||||||
|
async for chunk in iterator:
|
||||||
|
event_stream_buffer.add_data(chunk)
|
||||||
|
for event in event_stream_buffer:
|
||||||
|
message = self._parse_message_from_event(event)
|
||||||
|
if message:
|
||||||
|
_data = json.loads(message)
|
||||||
|
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
|
||||||
|
text=_data.get("text", ""),
|
||||||
|
is_finished=_data.get("is_finished", False),
|
||||||
|
finish_reason=_data.get("finish_reason", ""),
|
||||||
|
)
|
||||||
|
yield streaming_chunk
|
||||||
|
|
||||||
|
def _parse_message_from_event(self, event) -> Optional[str]:
|
||||||
|
response_dict = event.to_response_dict()
|
||||||
|
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
|
||||||
|
if response_dict["status_code"] != 200:
|
||||||
|
raise ValueError(f"Bad response code, expected 200: {response_dict}")
|
||||||
|
|
||||||
|
chunk = parsed_response.get("chunk")
|
||||||
|
if not chunk:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return chunk.get("bytes").decode() # type: ignore[no-any-return]
|
328
litellm/llms/clarifai.py
Normal file
328
litellm/llms/clarifai.py
Normal file
|
@ -0,0 +1,328 @@
|
||||||
|
import os, types, traceback
|
||||||
|
import json
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
from typing import Callable, Optional
|
||||||
|
from litellm.utils import ModelResponse, Usage, Choices, Message, CustomStreamWrapper
|
||||||
|
import litellm
|
||||||
|
import httpx
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
|
|
||||||
|
|
||||||
|
class ClarifaiError(Exception):
|
||||||
|
def __init__(self, status_code, message, url):
|
||||||
|
self.status_code = status_code
|
||||||
|
self.message = message
|
||||||
|
self.request = httpx.Request(
|
||||||
|
method="POST", url=url
|
||||||
|
)
|
||||||
|
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||||
|
super().__init__(
|
||||||
|
self.message
|
||||||
|
)
|
||||||
|
|
||||||
|
class ClarifaiConfig:
|
||||||
|
"""
|
||||||
|
Reference: https://clarifai.com/meta/Llama-2/models/llama2-70b-chat
|
||||||
|
TODO fill in the details
|
||||||
|
"""
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
temperature: Optional[int] = None
|
||||||
|
top_k: Optional[int] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[int] = None,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
locals_ = locals()
|
||||||
|
for key, value in locals_.items():
|
||||||
|
if key != "self" and value is not None:
|
||||||
|
setattr(self.__class__, key, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls):
|
||||||
|
return {
|
||||||
|
k: v
|
||||||
|
for k, v in cls.__dict__.items()
|
||||||
|
if not k.startswith("__")
|
||||||
|
and not isinstance(
|
||||||
|
v,
|
||||||
|
(
|
||||||
|
types.FunctionType,
|
||||||
|
types.BuiltinFunctionType,
|
||||||
|
classmethod,
|
||||||
|
staticmethod,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
and v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
def validate_environment(api_key):
|
||||||
|
headers = {
|
||||||
|
"accept": "application/json",
|
||||||
|
"content-type": "application/json",
|
||||||
|
}
|
||||||
|
if api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def completions_to_model(payload):
|
||||||
|
# if payload["n"] != 1:
|
||||||
|
# raise HTTPException(
|
||||||
|
# status_code=422,
|
||||||
|
# detail="Only one generation is supported. Please set candidate_count to 1.",
|
||||||
|
# )
|
||||||
|
|
||||||
|
params = {}
|
||||||
|
if temperature := payload.get("temperature"):
|
||||||
|
params["temperature"] = temperature
|
||||||
|
if max_tokens := payload.get("max_tokens"):
|
||||||
|
params["max_tokens"] = max_tokens
|
||||||
|
return {
|
||||||
|
"inputs": [{"data": {"text": {"raw": payload["prompt"]}}}],
|
||||||
|
"model": {"output_info": {"params": params}},
|
||||||
|
}
|
||||||
|
|
||||||
|
def process_response(
|
||||||
|
model,
|
||||||
|
prompt,
|
||||||
|
response,
|
||||||
|
model_response,
|
||||||
|
api_key,
|
||||||
|
data,
|
||||||
|
encoding,
|
||||||
|
logging_obj
|
||||||
|
):
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=prompt,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=response.text,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
try:
|
||||||
|
completion_response = response.json()
|
||||||
|
except Exception:
|
||||||
|
raise ClarifaiError(
|
||||||
|
message=response.text, status_code=response.status_code, url=model
|
||||||
|
)
|
||||||
|
# print(completion_response)
|
||||||
|
try:
|
||||||
|
choices_list = []
|
||||||
|
for idx, item in enumerate(completion_response["outputs"]):
|
||||||
|
if len(item["data"]["text"]["raw"]) > 0:
|
||||||
|
message_obj = Message(content=item["data"]["text"]["raw"])
|
||||||
|
else:
|
||||||
|
message_obj = Message(content=None)
|
||||||
|
choice_obj = Choices(
|
||||||
|
finish_reason="stop",
|
||||||
|
index=idx + 1, #check
|
||||||
|
message=message_obj,
|
||||||
|
)
|
||||||
|
choices_list.append(choice_obj)
|
||||||
|
model_response["choices"] = choices_list
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise ClarifaiError(
|
||||||
|
message=traceback.format_exc(), status_code=response.status_code, url=model
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate Usage
|
||||||
|
prompt_tokens = len(encoding.encode(prompt))
|
||||||
|
completion_tokens = len(
|
||||||
|
encoding.encode(model_response["choices"][0]["message"].get("content"))
|
||||||
|
)
|
||||||
|
model_response["model"] = model
|
||||||
|
model_response["usage"] = Usage(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=prompt_tokens + completion_tokens,
|
||||||
|
)
|
||||||
|
return model_response
|
||||||
|
|
||||||
|
def convert_model_to_url(model: str, api_base: str):
|
||||||
|
user_id, app_id, model_id = model.split(".")
|
||||||
|
return f"{api_base}/users/{user_id}/apps/{app_id}/models/{model_id}/outputs"
|
||||||
|
|
||||||
|
def get_prompt_model_name(url: str):
|
||||||
|
clarifai_model_name = url.split("/")[-2]
|
||||||
|
if "claude" in clarifai_model_name:
|
||||||
|
return "anthropic", clarifai_model_name.replace("_", ".")
|
||||||
|
if ("llama" in clarifai_model_name)or ("mistral" in clarifai_model_name):
|
||||||
|
return "", "meta-llama/llama-2-chat"
|
||||||
|
else:
|
||||||
|
return "", clarifai_model_name
|
||||||
|
|
||||||
|
async def async_completion(
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
api_base: str,
|
||||||
|
custom_prompt_dict: dict,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
encoding,
|
||||||
|
api_key,
|
||||||
|
logging_obj,
|
||||||
|
data=None,
|
||||||
|
optional_params=None,
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
headers={}):
|
||||||
|
|
||||||
|
async_handler = AsyncHTTPHandler(
|
||||||
|
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||||
|
)
|
||||||
|
response = await async_handler.post(
|
||||||
|
api_base, headers=headers, data=json.dumps(data)
|
||||||
|
)
|
||||||
|
|
||||||
|
return process_response(
|
||||||
|
model=model,
|
||||||
|
prompt=prompt,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
api_key=api_key,
|
||||||
|
data=data,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
def completion(
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
api_base: str,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
encoding,
|
||||||
|
api_key,
|
||||||
|
logging_obj,
|
||||||
|
custom_prompt_dict={},
|
||||||
|
acompletion=False,
|
||||||
|
optional_params=None,
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
):
|
||||||
|
headers = validate_environment(api_key)
|
||||||
|
model = convert_model_to_url(model, api_base)
|
||||||
|
prompt = " ".join(message["content"] for message in messages) # TODO
|
||||||
|
|
||||||
|
## Load Config
|
||||||
|
config = litellm.ClarifaiConfig.get_config()
|
||||||
|
for k, v in config.items():
|
||||||
|
if (
|
||||||
|
k not in optional_params
|
||||||
|
):
|
||||||
|
optional_params[k] = v
|
||||||
|
|
||||||
|
custom_llm_provider, orig_model_name = get_prompt_model_name(model)
|
||||||
|
if custom_llm_provider == "anthropic":
|
||||||
|
prompt = prompt_factory(
|
||||||
|
model=orig_model_name,
|
||||||
|
messages=messages,
|
||||||
|
api_key=api_key,
|
||||||
|
custom_llm_provider="clarifai"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = prompt_factory(
|
||||||
|
model=orig_model_name,
|
||||||
|
messages=messages,
|
||||||
|
api_key=api_key,
|
||||||
|
custom_llm_provider=custom_llm_provider
|
||||||
|
)
|
||||||
|
# print(prompt); exit(0)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"prompt": prompt,
|
||||||
|
**optional_params,
|
||||||
|
}
|
||||||
|
data = completions_to_model(data)
|
||||||
|
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.pre_call(
|
||||||
|
input=prompt,
|
||||||
|
api_key=api_key,
|
||||||
|
additional_args={
|
||||||
|
"complete_input_dict": data,
|
||||||
|
"headers": headers,
|
||||||
|
"api_base": api_base,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if acompletion==True:
|
||||||
|
return async_completion(
|
||||||
|
model=model,
|
||||||
|
prompt=prompt,
|
||||||
|
api_base=api_base,
|
||||||
|
custom_prompt_dict=custom_prompt_dict,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
api_key=api_key,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
data=data,
|
||||||
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
## COMPLETION CALL
|
||||||
|
response = requests.post(
|
||||||
|
model,
|
||||||
|
headers=headers,
|
||||||
|
data=json.dumps(data),
|
||||||
|
)
|
||||||
|
# print(response.content); exit()
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise ClarifaiError(status_code=response.status_code, message=response.text, url=model)
|
||||||
|
|
||||||
|
if "stream" in optional_params and optional_params["stream"] == True:
|
||||||
|
completion_stream = response.iter_lines()
|
||||||
|
stream_response = CustomStreamWrapper(
|
||||||
|
completion_stream=completion_stream,
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="clarifai",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
return stream_response
|
||||||
|
|
||||||
|
else:
|
||||||
|
return process_response(
|
||||||
|
model=model,
|
||||||
|
prompt=prompt,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
api_key=api_key,
|
||||||
|
data=data,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging_obj)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelResponseIterator:
|
||||||
|
def __init__(self, model_response):
|
||||||
|
self.model_response = model_response
|
||||||
|
self.is_done = False
|
||||||
|
|
||||||
|
# Sync iterator
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self.is_done:
|
||||||
|
raise StopIteration
|
||||||
|
self.is_done = True
|
||||||
|
return self.model_response
|
||||||
|
|
||||||
|
# Async iterator
|
||||||
|
def __aiter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
if self.is_done:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
self.is_done = True
|
||||||
|
return self.model_response
|
|
@ -1,11 +1,11 @@
|
||||||
import os, types
|
import os, types
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
import litellm
|
import litellm
|
||||||
import httpx
|
import httpx # type: ignore
|
||||||
from litellm.utils import ModelResponse, Usage
|
from litellm.utils import ModelResponse, Usage
|
||||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
import os, types
|
import os, types
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time, traceback
|
import time, traceback
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||||
import litellm
|
import litellm
|
||||||
import httpx
|
import httpx # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class CohereError(Exception):
|
class CohereError(Exception):
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
import os, types
|
import os, types
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time, traceback
|
import time, traceback
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||||
import litellm
|
import litellm
|
||||||
import httpx
|
import httpx # type: ignore
|
||||||
from .prompt_templates.factory import cohere_message_pt
|
from .prompt_templates.factory import cohere_message_pt
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -58,8 +58,15 @@ class AsyncHTTPHandler:
|
||||||
|
|
||||||
class HTTPHandler:
|
class HTTPHandler:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
|
self,
|
||||||
|
timeout: Optional[httpx.Timeout] = None,
|
||||||
|
concurrent_limit=1000,
|
||||||
|
client: Optional[httpx.Client] = None,
|
||||||
):
|
):
|
||||||
|
if timeout is None:
|
||||||
|
timeout = _DEFAULT_TIMEOUT
|
||||||
|
|
||||||
|
if client is None:
|
||||||
# Create a client with a connection pool
|
# Create a client with a connection pool
|
||||||
self.client = httpx.Client(
|
self.client = httpx.Client(
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
@ -68,6 +75,8 @@ class HTTPHandler:
|
||||||
max_keepalive_connections=concurrent_limit,
|
max_keepalive_connections=concurrent_limit,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
self.client = client
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
# Close the client when you're done with it
|
# Close the client when you're done with it
|
||||||
|
@ -82,11 +91,15 @@ class HTTPHandler:
|
||||||
def post(
|
def post(
|
||||||
self,
|
self,
|
||||||
url: str,
|
url: str,
|
||||||
data: Optional[dict] = None,
|
data: Optional[Union[dict, str]] = None,
|
||||||
params: Optional[dict] = None,
|
params: Optional[dict] = None,
|
||||||
headers: Optional[dict] = None,
|
headers: Optional[dict] = None,
|
||||||
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
response = self.client.post(url, data=data, params=params, headers=headers)
|
req = self.client.build_request(
|
||||||
|
"POST", url, data=data, params=params, headers=headers # type: ignore
|
||||||
|
)
|
||||||
|
response = self.client.send(req, stream=stream)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def __del__(self) -> None:
|
def __del__(self) -> None:
|
||||||
|
|
|
@ -6,10 +6,12 @@ import httpx, requests
|
||||||
from .base import BaseLLM
|
from .base import BaseLLM
|
||||||
import time
|
import time
|
||||||
import litellm
|
import litellm
|
||||||
from typing import Callable, Dict, List, Any
|
from typing import Callable, Dict, List, Any, Literal
|
||||||
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
|
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
|
from litellm.types.completion import ChatCompletionMessageToolCallParam
|
||||||
|
import enum
|
||||||
|
|
||||||
|
|
||||||
class HuggingfaceError(Exception):
|
class HuggingfaceError(Exception):
|
||||||
|
@ -39,11 +41,29 @@ class HuggingfaceError(Exception):
|
||||||
) # Call the base class constructor with the parameters it needs
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
|
hf_task_list = [
|
||||||
|
"text-generation-inference",
|
||||||
|
"conversational",
|
||||||
|
"text-classification",
|
||||||
|
"text-generation",
|
||||||
|
]
|
||||||
|
|
||||||
|
hf_tasks = Literal[
|
||||||
|
"text-generation-inference",
|
||||||
|
"conversational",
|
||||||
|
"text-classification",
|
||||||
|
"text-generation",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class HuggingfaceConfig:
|
class HuggingfaceConfig:
|
||||||
"""
|
"""
|
||||||
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
|
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
hf_task: Optional[hf_tasks] = (
|
||||||
|
None # litellm-specific param, used to know the api spec to use when calling huggingface api
|
||||||
|
)
|
||||||
best_of: Optional[int] = None
|
best_of: Optional[int] = None
|
||||||
decoder_input_details: Optional[bool] = None
|
decoder_input_details: Optional[bool] = None
|
||||||
details: Optional[bool] = True # enables returning logprobs + best of
|
details: Optional[bool] = True # enables returning logprobs + best of
|
||||||
|
@ -101,6 +121,51 @@ class HuggingfaceConfig:
|
||||||
and v is not None
|
and v is not None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_supported_openai_params(self):
|
||||||
|
return [
|
||||||
|
"stream",
|
||||||
|
"temperature",
|
||||||
|
"max_tokens",
|
||||||
|
"top_p",
|
||||||
|
"stop",
|
||||||
|
"n",
|
||||||
|
"echo",
|
||||||
|
]
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self, non_default_params: dict, optional_params: dict
|
||||||
|
) -> dict:
|
||||||
|
for param, value in non_default_params.items():
|
||||||
|
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
|
||||||
|
if param == "temperature":
|
||||||
|
if value == 0.0 or value == 0:
|
||||||
|
# hugging face exception raised when temp==0
|
||||||
|
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
|
||||||
|
value = 0.01
|
||||||
|
optional_params["temperature"] = value
|
||||||
|
if param == "top_p":
|
||||||
|
optional_params["top_p"] = value
|
||||||
|
if param == "n":
|
||||||
|
optional_params["best_of"] = value
|
||||||
|
optional_params["do_sample"] = (
|
||||||
|
True # Need to sample if you want best of for hf inference endpoints
|
||||||
|
)
|
||||||
|
if param == "stream":
|
||||||
|
optional_params["stream"] = value
|
||||||
|
if param == "stop":
|
||||||
|
optional_params["stop"] = value
|
||||||
|
if param == "max_tokens":
|
||||||
|
# HF TGI raises the following exception when max_new_tokens==0
|
||||||
|
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
|
||||||
|
if value == 0:
|
||||||
|
value = 1
|
||||||
|
optional_params["max_new_tokens"] = value
|
||||||
|
if param == "echo":
|
||||||
|
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
|
||||||
|
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
|
||||||
|
optional_params["decoder_input_details"] = True
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
|
||||||
def output_parser(generated_text: str):
|
def output_parser(generated_text: str):
|
||||||
"""
|
"""
|
||||||
|
@ -162,16 +227,18 @@ def read_tgi_conv_models():
|
||||||
return set(), set()
|
return set(), set()
|
||||||
|
|
||||||
|
|
||||||
def get_hf_task_for_model(model):
|
def get_hf_task_for_model(model: str) -> hf_tasks:
|
||||||
# read text file, cast it to set
|
# read text file, cast it to set
|
||||||
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
|
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
|
||||||
|
if model.split("/")[0] in hf_task_list:
|
||||||
|
return model.split("/")[0] # type: ignore
|
||||||
tgi_models, conversational_models = read_tgi_conv_models()
|
tgi_models, conversational_models = read_tgi_conv_models()
|
||||||
if model in tgi_models:
|
if model in tgi_models:
|
||||||
return "text-generation-inference"
|
return "text-generation-inference"
|
||||||
elif model in conversational_models:
|
elif model in conversational_models:
|
||||||
return "conversational"
|
return "conversational"
|
||||||
elif "roneneldan/TinyStories" in model:
|
elif "roneneldan/TinyStories" in model:
|
||||||
return None
|
return "text-generation"
|
||||||
else:
|
else:
|
||||||
return "text-generation-inference" # default to tgi
|
return "text-generation-inference" # default to tgi
|
||||||
|
|
||||||
|
@ -202,7 +269,7 @@ class Huggingface(BaseLLM):
|
||||||
self,
|
self,
|
||||||
completion_response,
|
completion_response,
|
||||||
model_response,
|
model_response,
|
||||||
task,
|
task: hf_tasks,
|
||||||
optional_params,
|
optional_params,
|
||||||
encoding,
|
encoding,
|
||||||
input_text,
|
input_text,
|
||||||
|
@ -270,6 +337,10 @@ class Huggingface(BaseLLM):
|
||||||
)
|
)
|
||||||
choices_list.append(choice_obj)
|
choices_list.append(choice_obj)
|
||||||
model_response["choices"].extend(choices_list)
|
model_response["choices"].extend(choices_list)
|
||||||
|
elif task == "text-classification":
|
||||||
|
model_response["choices"][0]["message"]["content"] = json.dumps(
|
||||||
|
completion_response
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if len(completion_response[0]["generated_text"]) > 0:
|
if len(completion_response[0]["generated_text"]) > 0:
|
||||||
model_response["choices"][0]["message"]["content"] = output_parser(
|
model_response["choices"][0]["message"]["content"] = output_parser(
|
||||||
|
@ -322,9 +393,9 @@ class Huggingface(BaseLLM):
|
||||||
encoding,
|
encoding,
|
||||||
api_key,
|
api_key,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
|
optional_params: dict,
|
||||||
custom_prompt_dict={},
|
custom_prompt_dict={},
|
||||||
acompletion: bool = False,
|
acompletion: bool = False,
|
||||||
optional_params=None,
|
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
):
|
):
|
||||||
|
@ -333,6 +404,12 @@ class Huggingface(BaseLLM):
|
||||||
try:
|
try:
|
||||||
headers = self.validate_environment(api_key, headers)
|
headers = self.validate_environment(api_key, headers)
|
||||||
task = get_hf_task_for_model(model)
|
task = get_hf_task_for_model(model)
|
||||||
|
## VALIDATE API FORMAT
|
||||||
|
if task is None or not isinstance(task, str) or task not in hf_task_list:
|
||||||
|
raise Exception(
|
||||||
|
"Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
|
||||||
|
)
|
||||||
|
|
||||||
print_verbose(f"{model}, {task}")
|
print_verbose(f"{model}, {task}")
|
||||||
completion_url = ""
|
completion_url = ""
|
||||||
input_text = ""
|
input_text = ""
|
||||||
|
@ -399,10 +476,11 @@ class Huggingface(BaseLLM):
|
||||||
data = {
|
data = {
|
||||||
"inputs": prompt,
|
"inputs": prompt,
|
||||||
"parameters": optional_params,
|
"parameters": optional_params,
|
||||||
"stream": (
|
"stream": ( # type: ignore
|
||||||
True
|
True
|
||||||
if "stream" in optional_params
|
if "stream" in optional_params
|
||||||
and optional_params["stream"] == True
|
and isinstance(optional_params["stream"], bool)
|
||||||
|
and optional_params["stream"] == True # type: ignore
|
||||||
else False
|
else False
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
@ -432,14 +510,15 @@ class Huggingface(BaseLLM):
|
||||||
inference_params.pop("return_full_text")
|
inference_params.pop("return_full_text")
|
||||||
data = {
|
data = {
|
||||||
"inputs": prompt,
|
"inputs": prompt,
|
||||||
"parameters": inference_params,
|
}
|
||||||
"stream": (
|
if task == "text-generation-inference":
|
||||||
|
data["parameters"] = inference_params
|
||||||
|
data["stream"] = ( # type: ignore
|
||||||
True
|
True
|
||||||
if "stream" in optional_params
|
if "stream" in optional_params
|
||||||
and optional_params["stream"] == True
|
and optional_params["stream"] == True
|
||||||
else False
|
else False
|
||||||
),
|
)
|
||||||
}
|
|
||||||
input_text = prompt
|
input_text = prompt
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
|
@ -530,10 +609,10 @@ class Huggingface(BaseLLM):
|
||||||
isinstance(completion_response, dict)
|
isinstance(completion_response, dict)
|
||||||
and "error" in completion_response
|
and "error" in completion_response
|
||||||
):
|
):
|
||||||
print_verbose(f"completion error: {completion_response['error']}")
|
print_verbose(f"completion error: {completion_response['error']}") # type: ignore
|
||||||
print_verbose(f"response.status_code: {response.status_code}")
|
print_verbose(f"response.status_code: {response.status_code}")
|
||||||
raise HuggingfaceError(
|
raise HuggingfaceError(
|
||||||
message=completion_response["error"],
|
message=completion_response["error"], # type: ignore
|
||||||
status_code=response.status_code,
|
status_code=response.status_code,
|
||||||
)
|
)
|
||||||
return self.convert_to_model_response_object(
|
return self.convert_to_model_response_object(
|
||||||
|
@ -562,7 +641,7 @@ class Huggingface(BaseLLM):
|
||||||
data: dict,
|
data: dict,
|
||||||
headers: dict,
|
headers: dict,
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
task: str,
|
task: hf_tasks,
|
||||||
encoding: Any,
|
encoding: Any,
|
||||||
input_text: str,
|
input_text: str,
|
||||||
model: str,
|
model: str,
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import os, types
|
import os, types
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time, traceback
|
import time, traceback
|
||||||
from typing import Callable, Optional, List
|
from typing import Callable, Optional, List
|
||||||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import os, types
|
import os, types
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
import litellm
|
import litellm
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
import requests, types, time
|
from itertools import chain
|
||||||
|
import requests, types, time # type: ignore
|
||||||
import json, uuid
|
import json, uuid
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import litellm
|
import litellm
|
||||||
import httpx, aiohttp, asyncio
|
import httpx, aiohttp, asyncio # type: ignore
|
||||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
|
|
||||||
|
|
||||||
|
@ -212,25 +213,31 @@ def get_ollama_response(
|
||||||
|
|
||||||
## RESPONSE OBJECT
|
## RESPONSE OBJECT
|
||||||
model_response["choices"][0]["finish_reason"] = "stop"
|
model_response["choices"][0]["finish_reason"] = "stop"
|
||||||
if optional_params.get("format", "") == "json":
|
if data.get("format", "") == "json":
|
||||||
function_call = json.loads(response_json["response"])
|
function_call = json.loads(response_json["response"])
|
||||||
message = litellm.Message(
|
message = litellm.Message(
|
||||||
content=None,
|
content=None,
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
{
|
{
|
||||||
"id": f"call_{str(uuid.uuid4())}",
|
"id": f"call_{str(uuid.uuid4())}",
|
||||||
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
|
"function": {
|
||||||
|
"name": function_call["name"],
|
||||||
|
"arguments": json.dumps(function_call["arguments"]),
|
||||||
|
},
|
||||||
"type": "function",
|
"type": "function",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
model_response["choices"][0]["message"] = message
|
model_response["choices"][0]["message"] = message
|
||||||
|
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||||
else:
|
else:
|
||||||
model_response["choices"][0]["message"]["content"] = response_json["response"]
|
model_response["choices"][0]["message"]["content"] = response_json["response"]
|
||||||
model_response["created"] = int(time.time())
|
model_response["created"] = int(time.time())
|
||||||
model_response["model"] = "ollama/" + model
|
model_response["model"] = "ollama/" + model
|
||||||
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=()))) # type: ignore
|
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=()))) # type: ignore
|
||||||
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
|
completion_tokens = response_json.get(
|
||||||
|
"eval_count", len(response_json.get("message", dict()).get("content", ""))
|
||||||
|
)
|
||||||
model_response["usage"] = litellm.Usage(
|
model_response["usage"] = litellm.Usage(
|
||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
completion_tokens=completion_tokens,
|
completion_tokens=completion_tokens,
|
||||||
|
@ -255,6 +262,35 @@ def ollama_completion_stream(url, data, logging_obj):
|
||||||
custom_llm_provider="ollama",
|
custom_llm_provider="ollama",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
)
|
)
|
||||||
|
# If format is JSON, this was a function call
|
||||||
|
# Gather all chunks and return the function call as one delta to simplify parsing
|
||||||
|
if data.get("format", "") == "json":
|
||||||
|
first_chunk = next(streamwrapper)
|
||||||
|
response_content = "".join(
|
||||||
|
chunk.choices[0].delta.content
|
||||||
|
for chunk in chain([first_chunk], streamwrapper)
|
||||||
|
if chunk.choices[0].delta.content
|
||||||
|
)
|
||||||
|
|
||||||
|
function_call = json.loads(response_content)
|
||||||
|
delta = litellm.utils.Delta(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"id": f"call_{str(uuid.uuid4())}",
|
||||||
|
"function": {
|
||||||
|
"name": function_call["name"],
|
||||||
|
"arguments": json.dumps(function_call["arguments"]),
|
||||||
|
},
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
model_response = first_chunk
|
||||||
|
model_response["choices"][0]["delta"] = delta
|
||||||
|
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||||
|
yield model_response
|
||||||
|
else:
|
||||||
for transformed_chunk in streamwrapper:
|
for transformed_chunk in streamwrapper:
|
||||||
yield transformed_chunk
|
yield transformed_chunk
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -278,6 +314,38 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
|
||||||
custom_llm_provider="ollama",
|
custom_llm_provider="ollama",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If format is JSON, this was a function call
|
||||||
|
# Gather all chunks and return the function call as one delta to simplify parsing
|
||||||
|
if data.get("format", "") == "json":
|
||||||
|
first_chunk = await anext(streamwrapper)
|
||||||
|
first_chunk_content = first_chunk.choices[0].delta.content or ""
|
||||||
|
response_content = first_chunk_content + "".join(
|
||||||
|
[
|
||||||
|
chunk.choices[0].delta.content
|
||||||
|
async for chunk in streamwrapper
|
||||||
|
if chunk.choices[0].delta.content
|
||||||
|
]
|
||||||
|
)
|
||||||
|
function_call = json.loads(response_content)
|
||||||
|
delta = litellm.utils.Delta(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"id": f"call_{str(uuid.uuid4())}",
|
||||||
|
"function": {
|
||||||
|
"name": function_call["name"],
|
||||||
|
"arguments": json.dumps(function_call["arguments"]),
|
||||||
|
},
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
model_response = first_chunk
|
||||||
|
model_response["choices"][0]["delta"] = delta
|
||||||
|
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||||
|
yield model_response
|
||||||
|
else:
|
||||||
async for transformed_chunk in streamwrapper:
|
async for transformed_chunk in streamwrapper:
|
||||||
yield transformed_chunk
|
yield transformed_chunk
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -317,12 +385,16 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
{
|
{
|
||||||
"id": f"call_{str(uuid.uuid4())}",
|
"id": f"call_{str(uuid.uuid4())}",
|
||||||
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
|
"function": {
|
||||||
|
"name": function_call["name"],
|
||||||
|
"arguments": json.dumps(function_call["arguments"]),
|
||||||
|
},
|
||||||
"type": "function",
|
"type": "function",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
model_response["choices"][0]["message"] = message
|
model_response["choices"][0]["message"] = message
|
||||||
|
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||||
else:
|
else:
|
||||||
model_response["choices"][0]["message"]["content"] = response_json[
|
model_response["choices"][0]["message"]["content"] = response_json[
|
||||||
"response"
|
"response"
|
||||||
|
@ -330,7 +402,10 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
|
||||||
model_response["created"] = int(time.time())
|
model_response["created"] = int(time.time())
|
||||||
model_response["model"] = "ollama/" + data["model"]
|
model_response["model"] = "ollama/" + data["model"]
|
||||||
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=()))) # type: ignore
|
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=()))) # type: ignore
|
||||||
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
|
completion_tokens = response_json.get(
|
||||||
|
"eval_count",
|
||||||
|
len(response_json.get("message", dict()).get("content", "")),
|
||||||
|
)
|
||||||
model_response["usage"] = litellm.Usage(
|
model_response["usage"] = litellm.Usage(
|
||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
completion_tokens=completion_tokens,
|
completion_tokens=completion_tokens,
|
||||||
|
@ -417,3 +492,25 @@ async def ollama_aembeddings(
|
||||||
"total_tokens": total_input_tokens,
|
"total_tokens": total_input_tokens,
|
||||||
}
|
}
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
|
|
||||||
|
def ollama_embeddings(
|
||||||
|
api_base: str,
|
||||||
|
model: str,
|
||||||
|
prompts: list,
|
||||||
|
optional_params=None,
|
||||||
|
logging_obj=None,
|
||||||
|
model_response=None,
|
||||||
|
encoding=None,
|
||||||
|
):
|
||||||
|
return asyncio.run(
|
||||||
|
ollama_aembeddings(
|
||||||
|
api_base,
|
||||||
|
model,
|
||||||
|
prompts,
|
||||||
|
optional_params,
|
||||||
|
logging_obj,
|
||||||
|
model_response,
|
||||||
|
encoding,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from itertools import chain
|
||||||
import requests, types, time
|
import requests, types, time
|
||||||
import json, uuid
|
import json, uuid
|
||||||
import traceback
|
import traceback
|
||||||
|
@ -297,8 +298,9 @@ def get_ollama_response(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
model_response["choices"][0]["message"] = message
|
model_response["choices"][0]["message"] = message
|
||||||
|
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||||
else:
|
else:
|
||||||
model_response["choices"][0]["message"] = response_json["message"]
|
model_response["choices"][0]["message"]["content"] = response_json["message"]["content"]
|
||||||
model_response["created"] = int(time.time())
|
model_response["created"] = int(time.time())
|
||||||
model_response["model"] = "ollama/" + model
|
model_response["model"] = "ollama/" + model
|
||||||
prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages)) # type: ignore
|
prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages)) # type: ignore
|
||||||
|
@ -335,6 +337,33 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
|
||||||
custom_llm_provider="ollama_chat",
|
custom_llm_provider="ollama_chat",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If format is JSON, this was a function call
|
||||||
|
# Gather all chunks and return the function call as one delta to simplify parsing
|
||||||
|
if data.get("format", "") == "json":
|
||||||
|
first_chunk = next(streamwrapper)
|
||||||
|
response_content = "".join(
|
||||||
|
chunk.choices[0].delta.content
|
||||||
|
for chunk in chain([first_chunk], streamwrapper)
|
||||||
|
if chunk.choices[0].delta.content
|
||||||
|
)
|
||||||
|
|
||||||
|
function_call = json.loads(response_content)
|
||||||
|
delta = litellm.utils.Delta(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"id": f"call_{str(uuid.uuid4())}",
|
||||||
|
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
model_response = first_chunk
|
||||||
|
model_response["choices"][0]["delta"] = delta
|
||||||
|
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||||
|
yield model_response
|
||||||
|
else:
|
||||||
for transformed_chunk in streamwrapper:
|
for transformed_chunk in streamwrapper:
|
||||||
yield transformed_chunk
|
yield transformed_chunk
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -366,6 +395,34 @@ async def ollama_async_streaming(
|
||||||
custom_llm_provider="ollama_chat",
|
custom_llm_provider="ollama_chat",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If format is JSON, this was a function call
|
||||||
|
# Gather all chunks and return the function call as one delta to simplify parsing
|
||||||
|
if data.get("format", "") == "json":
|
||||||
|
first_chunk = await anext(streamwrapper)
|
||||||
|
first_chunk_content = first_chunk.choices[0].delta.content or ""
|
||||||
|
response_content = first_chunk_content + "".join(
|
||||||
|
[
|
||||||
|
chunk.choices[0].delta.content
|
||||||
|
async for chunk in streamwrapper
|
||||||
|
if chunk.choices[0].delta.content]
|
||||||
|
)
|
||||||
|
function_call = json.loads(response_content)
|
||||||
|
delta = litellm.utils.Delta(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"id": f"call_{str(uuid.uuid4())}",
|
||||||
|
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
model_response = first_chunk
|
||||||
|
model_response["choices"][0]["delta"] = delta
|
||||||
|
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||||
|
yield model_response
|
||||||
|
else:
|
||||||
async for transformed_chunk in streamwrapper:
|
async for transformed_chunk in streamwrapper:
|
||||||
yield transformed_chunk
|
yield transformed_chunk
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -425,8 +482,9 @@ async def ollama_acompletion(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
model_response["choices"][0]["message"] = message
|
model_response["choices"][0]["message"] = message
|
||||||
|
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||||
else:
|
else:
|
||||||
model_response["choices"][0]["message"] = response_json["message"]
|
model_response["choices"][0]["message"]["content"] = response_json["message"]["content"]
|
||||||
|
|
||||||
model_response["created"] = int(time.time())
|
model_response["created"] = int(time.time())
|
||||||
model_response["model"] = "ollama_chat/" + data["model"]
|
model_response["model"] = "ollama_chat/" + data["model"]
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
from litellm.utils import ModelResponse, Usage
|
from litellm.utils import ModelResponse, Usage
|
||||||
|
|
|
@ -1,4 +1,13 @@
|
||||||
from typing import Optional, Union, Any, BinaryIO
|
from typing import (
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
Any,
|
||||||
|
BinaryIO,
|
||||||
|
Literal,
|
||||||
|
Iterable,
|
||||||
|
)
|
||||||
|
from typing_extensions import override
|
||||||
|
from pydantic import BaseModel
|
||||||
import types, time, json, traceback
|
import types, time, json, traceback
|
||||||
import httpx
|
import httpx
|
||||||
from .base import BaseLLM
|
from .base import BaseLLM
|
||||||
|
@ -13,10 +22,10 @@ from litellm.utils import (
|
||||||
TextCompletionResponse,
|
TextCompletionResponse,
|
||||||
)
|
)
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
import aiohttp, requests
|
|
||||||
import litellm
|
import litellm
|
||||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
from openai import OpenAI, AsyncOpenAI
|
from openai import OpenAI, AsyncOpenAI
|
||||||
|
from ..types.llms.openai import *
|
||||||
|
|
||||||
|
|
||||||
class OpenAIError(Exception):
|
class OpenAIError(Exception):
|
||||||
|
@ -44,6 +53,113 @@ class OpenAIError(Exception):
|
||||||
) # Call the base class constructor with the parameters it needs
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
|
class MistralConfig:
|
||||||
|
"""
|
||||||
|
Reference: https://docs.mistral.ai/api/
|
||||||
|
|
||||||
|
The class `MistralConfig` provides configuration for the Mistral's Chat API interface. Below are the parameters:
|
||||||
|
|
||||||
|
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. API Default - 0.7.
|
||||||
|
|
||||||
|
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. API Default - 1.
|
||||||
|
|
||||||
|
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. API Default - null.
|
||||||
|
|
||||||
|
- `tools` (list or null): A list of available tools for the model. Use this to specify functions for which the model can generate JSON inputs.
|
||||||
|
|
||||||
|
- `tool_choice` (string - 'auto'/'any'/'none' or null): Specifies if/how functions are called. If set to none the model won't call a function and will generate a message instead. If set to auto the model can choose to either generate a message or call a function. If set to any the model is forced to call a function. Default - 'auto'.
|
||||||
|
|
||||||
|
- `random_seed` (integer or null): The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||||
|
|
||||||
|
- `safe_prompt` (boolean): Whether to inject a safety prompt before all conversations. API Default - 'false'.
|
||||||
|
|
||||||
|
- `response_format` (object or null): An object specifying the format that the model must output. Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
temperature: Optional[int] = None
|
||||||
|
top_p: Optional[int] = None
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
tools: Optional[list] = None
|
||||||
|
tool_choice: Optional[Literal["auto", "any", "none"]] = None
|
||||||
|
random_seed: Optional[int] = None
|
||||||
|
safe_prompt: Optional[bool] = None
|
||||||
|
response_format: Optional[dict] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
temperature: Optional[int] = None,
|
||||||
|
top_p: Optional[int] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
tools: Optional[list] = None,
|
||||||
|
tool_choice: Optional[Literal["auto", "any", "none"]] = None,
|
||||||
|
random_seed: Optional[int] = None,
|
||||||
|
safe_prompt: Optional[bool] = None,
|
||||||
|
response_format: Optional[dict] = None,
|
||||||
|
) -> None:
|
||||||
|
locals_ = locals()
|
||||||
|
for key, value in locals_.items():
|
||||||
|
if key != "self" and value is not None:
|
||||||
|
setattr(self.__class__, key, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls):
|
||||||
|
return {
|
||||||
|
k: v
|
||||||
|
for k, v in cls.__dict__.items()
|
||||||
|
if not k.startswith("__")
|
||||||
|
and not isinstance(
|
||||||
|
v,
|
||||||
|
(
|
||||||
|
types.FunctionType,
|
||||||
|
types.BuiltinFunctionType,
|
||||||
|
classmethod,
|
||||||
|
staticmethod,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
and v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_supported_openai_params(self):
|
||||||
|
return [
|
||||||
|
"stream",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"max_tokens",
|
||||||
|
"tools",
|
||||||
|
"tool_choice",
|
||||||
|
"seed",
|
||||||
|
"response_format",
|
||||||
|
]
|
||||||
|
|
||||||
|
def _map_tool_choice(self, tool_choice: str) -> str:
|
||||||
|
if tool_choice == "auto" or tool_choice == "none":
|
||||||
|
return tool_choice
|
||||||
|
elif tool_choice == "required":
|
||||||
|
return "any"
|
||||||
|
else: # openai 'tool_choice' object param not supported by Mistral API
|
||||||
|
return "any"
|
||||||
|
|
||||||
|
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||||
|
for param, value in non_default_params.items():
|
||||||
|
if param == "max_tokens":
|
||||||
|
optional_params["max_tokens"] = value
|
||||||
|
if param == "tools":
|
||||||
|
optional_params["tools"] = value
|
||||||
|
if param == "stream" and value == True:
|
||||||
|
optional_params["stream"] = value
|
||||||
|
if param == "temperature":
|
||||||
|
optional_params["temperature"] = value
|
||||||
|
if param == "top_p":
|
||||||
|
optional_params["top_p"] = value
|
||||||
|
if param == "tool_choice" and isinstance(value, str):
|
||||||
|
optional_params["tool_choice"] = self._map_tool_choice(
|
||||||
|
tool_choice=value
|
||||||
|
)
|
||||||
|
if param == "seed":
|
||||||
|
optional_params["extra_body"] = {"random_seed": value}
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
|
||||||
class OpenAIConfig:
|
class OpenAIConfig:
|
||||||
"""
|
"""
|
||||||
Reference: https://platform.openai.com/docs/api-reference/chat/create
|
Reference: https://platform.openai.com/docs/api-reference/chat/create
|
||||||
|
@ -246,7 +362,7 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
def completion(
|
def completion(
|
||||||
self,
|
self,
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
timeout: float,
|
timeout: Union[float, httpx.Timeout],
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
messages: Optional[list] = None,
|
messages: Optional[list] = None,
|
||||||
print_verbose: Optional[Callable] = None,
|
print_verbose: Optional[Callable] = None,
|
||||||
|
@ -271,9 +387,12 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
if model is None or messages is None:
|
if model is None or messages is None:
|
||||||
raise OpenAIError(status_code=422, message=f"Missing model or messages")
|
raise OpenAIError(status_code=422, message=f"Missing model or messages")
|
||||||
|
|
||||||
if not isinstance(timeout, float):
|
if not isinstance(timeout, float) and not isinstance(
|
||||||
|
timeout, httpx.Timeout
|
||||||
|
):
|
||||||
raise OpenAIError(
|
raise OpenAIError(
|
||||||
status_code=422, message=f"Timeout needs to be a float"
|
status_code=422,
|
||||||
|
message=f"Timeout needs to be a float or httpx.Timeout",
|
||||||
)
|
)
|
||||||
|
|
||||||
if custom_llm_provider != "openai":
|
if custom_llm_provider != "openai":
|
||||||
|
@ -425,7 +544,7 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
self,
|
self,
|
||||||
data: dict,
|
data: dict,
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
timeout: float,
|
timeout: Union[float, httpx.Timeout],
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
organization: Optional[str] = None,
|
organization: Optional[str] = None,
|
||||||
|
@ -480,7 +599,7 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
def streaming(
|
def streaming(
|
||||||
self,
|
self,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
timeout: float,
|
timeout: Union[float, httpx.Timeout],
|
||||||
data: dict,
|
data: dict,
|
||||||
model: str,
|
model: str,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
@ -518,13 +637,14 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
model=model,
|
model=model,
|
||||||
custom_llm_provider="openai",
|
custom_llm_provider="openai",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
|
stream_options=data.get("stream_options", None),
|
||||||
)
|
)
|
||||||
return streamwrapper
|
return streamwrapper
|
||||||
|
|
||||||
async def async_streaming(
|
async def async_streaming(
|
||||||
self,
|
self,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
timeout: float,
|
timeout: Union[float, httpx.Timeout],
|
||||||
data: dict,
|
data: dict,
|
||||||
model: str,
|
model: str,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
@ -567,6 +687,7 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
model=model,
|
model=model,
|
||||||
custom_llm_provider="openai",
|
custom_llm_provider="openai",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
|
stream_options=data.get("stream_options", None),
|
||||||
)
|
)
|
||||||
return streamwrapper
|
return streamwrapper
|
||||||
except (
|
except (
|
||||||
|
@ -1191,6 +1312,7 @@ class OpenAITextCompletion(BaseLLM):
|
||||||
model=model,
|
model=model,
|
||||||
custom_llm_provider="text-completion-openai",
|
custom_llm_provider="text-completion-openai",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
|
stream_options=data.get("stream_options", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
for chunk in streamwrapper:
|
for chunk in streamwrapper:
|
||||||
|
@ -1229,7 +1351,228 @@ class OpenAITextCompletion(BaseLLM):
|
||||||
model=model,
|
model=model,
|
||||||
custom_llm_provider="text-completion-openai",
|
custom_llm_provider="text-completion-openai",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
|
stream_options=data.get("stream_options", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
async for transformed_chunk in streamwrapper:
|
async for transformed_chunk in streamwrapper:
|
||||||
yield transformed_chunk
|
yield transformed_chunk
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIAssistantsAPI(BaseLLM):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def get_openai_client(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
organization: Optional[str],
|
||||||
|
client: Optional[OpenAI] = None,
|
||||||
|
) -> OpenAI:
|
||||||
|
received_args = locals()
|
||||||
|
if client is None:
|
||||||
|
data = {}
|
||||||
|
for k, v in received_args.items():
|
||||||
|
if k == "self" or k == "client":
|
||||||
|
pass
|
||||||
|
elif k == "api_base" and v is not None:
|
||||||
|
data["base_url"] = v
|
||||||
|
elif v is not None:
|
||||||
|
data[k] = v
|
||||||
|
openai_client = OpenAI(**data) # type: ignore
|
||||||
|
else:
|
||||||
|
openai_client = client
|
||||||
|
|
||||||
|
return openai_client
|
||||||
|
|
||||||
|
### ASSISTANTS ###
|
||||||
|
|
||||||
|
def get_assistants(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
organization: Optional[str],
|
||||||
|
client: Optional[OpenAI],
|
||||||
|
) -> SyncCursorPage[Assistant]:
|
||||||
|
openai_client = self.get_openai_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
organization=organization,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = openai_client.beta.assistants.list()
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
### MESSAGES ###
|
||||||
|
|
||||||
|
def add_message(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
message_data: MessageData,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
organization: Optional[str],
|
||||||
|
client: Optional[OpenAI] = None,
|
||||||
|
) -> OpenAIMessage:
|
||||||
|
|
||||||
|
openai_client = self.get_openai_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
organization=organization,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( # type: ignore
|
||||||
|
thread_id, **message_data # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
response_obj: Optional[OpenAIMessage] = None
|
||||||
|
if getattr(thread_message, "status", None) is None:
|
||||||
|
thread_message.status = "completed"
|
||||||
|
response_obj = OpenAIMessage(**thread_message.dict())
|
||||||
|
else:
|
||||||
|
response_obj = OpenAIMessage(**thread_message.dict())
|
||||||
|
return response_obj
|
||||||
|
|
||||||
|
def get_messages(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
organization: Optional[str],
|
||||||
|
client: Optional[OpenAI] = None,
|
||||||
|
) -> SyncCursorPage[OpenAIMessage]:
|
||||||
|
openai_client = self.get_openai_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
organization=organization,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = openai_client.beta.threads.messages.list(thread_id=thread_id)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
### THREADS ###
|
||||||
|
|
||||||
|
def create_thread(
|
||||||
|
self,
|
||||||
|
metadata: Optional[dict],
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
organization: Optional[str],
|
||||||
|
client: Optional[OpenAI],
|
||||||
|
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
||||||
|
) -> Thread:
|
||||||
|
"""
|
||||||
|
Here's an example:
|
||||||
|
```
|
||||||
|
from litellm.llms.openai import OpenAIAssistantsAPI, MessageData
|
||||||
|
|
||||||
|
# create thread
|
||||||
|
message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
|
||||||
|
openai_api.create_thread(messages=[message])
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
openai_client = self.get_openai_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
organization=organization,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = {}
|
||||||
|
if messages is not None:
|
||||||
|
data["messages"] = messages # type: ignore
|
||||||
|
if metadata is not None:
|
||||||
|
data["metadata"] = metadata # type: ignore
|
||||||
|
|
||||||
|
message_thread = openai_client.beta.threads.create(**data) # type: ignore
|
||||||
|
|
||||||
|
return Thread(**message_thread.dict())
|
||||||
|
|
||||||
|
def get_thread(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
organization: Optional[str],
|
||||||
|
client: Optional[OpenAI],
|
||||||
|
) -> Thread:
|
||||||
|
openai_client = self.get_openai_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
organization=organization,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = openai_client.beta.threads.retrieve(thread_id=thread_id)
|
||||||
|
|
||||||
|
return Thread(**response.dict())
|
||||||
|
|
||||||
|
def delete_thread(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
### RUNS ###
|
||||||
|
|
||||||
|
def run_thread(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
assistant_id: str,
|
||||||
|
additional_instructions: Optional[str],
|
||||||
|
instructions: Optional[str],
|
||||||
|
metadata: Optional[object],
|
||||||
|
model: Optional[str],
|
||||||
|
stream: Optional[bool],
|
||||||
|
tools: Optional[Iterable[AssistantToolParam]],
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
organization: Optional[str],
|
||||||
|
client: Optional[OpenAI],
|
||||||
|
) -> Run:
|
||||||
|
openai_client = self.get_openai_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
organization=organization,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
|
||||||
|
thread_id=thread_id,
|
||||||
|
assistant_id=assistant_id,
|
||||||
|
additional_instructions=additional_instructions,
|
||||||
|
instructions=instructions,
|
||||||
|
metadata=metadata,
|
||||||
|
model=model,
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import os, types
|
import os, types
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
import litellm
|
import litellm
|
||||||
|
|
518
litellm/llms/predibase.py
Normal file
518
litellm/llms/predibase.py
Normal file
|
@ -0,0 +1,518 @@
|
||||||
|
# What is this?
|
||||||
|
## Controller file for Predibase Integration - https://predibase.com/
|
||||||
|
|
||||||
|
|
||||||
|
import os, types
|
||||||
|
import json
|
||||||
|
from enum import Enum
|
||||||
|
import requests, copy # type: ignore
|
||||||
|
import time
|
||||||
|
from typing import Callable, Optional, List, Literal, Union
|
||||||
|
from litellm.utils import (
|
||||||
|
ModelResponse,
|
||||||
|
Usage,
|
||||||
|
map_finish_reason,
|
||||||
|
CustomStreamWrapper,
|
||||||
|
Message,
|
||||||
|
Choices,
|
||||||
|
)
|
||||||
|
import litellm
|
||||||
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
|
from .base import BaseLLM
|
||||||
|
import httpx # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class PredibaseError(Exception):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
status_code,
|
||||||
|
message,
|
||||||
|
request: Optional[httpx.Request] = None,
|
||||||
|
response: Optional[httpx.Response] = None,
|
||||||
|
):
|
||||||
|
self.status_code = status_code
|
||||||
|
self.message = message
|
||||||
|
if request is not None:
|
||||||
|
self.request = request
|
||||||
|
else:
|
||||||
|
self.request = httpx.Request(
|
||||||
|
method="POST",
|
||||||
|
url="https://docs.predibase.com/user-guide/inference/rest_api",
|
||||||
|
)
|
||||||
|
if response is not None:
|
||||||
|
self.response = response
|
||||||
|
else:
|
||||||
|
self.response = httpx.Response(
|
||||||
|
status_code=status_code, request=self.request
|
||||||
|
)
|
||||||
|
super().__init__(
|
||||||
|
self.message
|
||||||
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
|
class PredibaseConfig:
|
||||||
|
"""
|
||||||
|
Reference: https://docs.predibase.com/user-guide/inference/rest_api
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
adapter_id: Optional[str] = None
|
||||||
|
adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
|
||||||
|
best_of: Optional[int] = None
|
||||||
|
decoder_input_details: Optional[bool] = None
|
||||||
|
details: bool = True # enables returning logprobs + best of
|
||||||
|
max_new_tokens: int = (
|
||||||
|
256 # openai default - requests hang if max_new_tokens not given
|
||||||
|
)
|
||||||
|
repetition_penalty: Optional[float] = None
|
||||||
|
return_full_text: Optional[bool] = (
|
||||||
|
False # by default don't return the input as part of the output
|
||||||
|
)
|
||||||
|
seed: Optional[int] = None
|
||||||
|
stop: Optional[List[str]] = None
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
top_k: Optional[int] = None
|
||||||
|
top_p: Optional[int] = None
|
||||||
|
truncate: Optional[int] = None
|
||||||
|
typical_p: Optional[float] = None
|
||||||
|
watermark: Optional[bool] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
best_of: Optional[int] = None,
|
||||||
|
decoder_input_details: Optional[bool] = None,
|
||||||
|
details: Optional[bool] = None,
|
||||||
|
max_new_tokens: Optional[int] = None,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
return_full_text: Optional[bool] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
top_p: Optional[int] = None,
|
||||||
|
truncate: Optional[int] = None,
|
||||||
|
typical_p: Optional[float] = None,
|
||||||
|
watermark: Optional[bool] = None,
|
||||||
|
) -> None:
|
||||||
|
locals_ = locals()
|
||||||
|
for key, value in locals_.items():
|
||||||
|
if key != "self" and value is not None:
|
||||||
|
setattr(self.__class__, key, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls):
|
||||||
|
return {
|
||||||
|
k: v
|
||||||
|
for k, v in cls.__dict__.items()
|
||||||
|
if not k.startswith("__")
|
||||||
|
and not isinstance(
|
||||||
|
v,
|
||||||
|
(
|
||||||
|
types.FunctionType,
|
||||||
|
types.BuiltinFunctionType,
|
||||||
|
classmethod,
|
||||||
|
staticmethod,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
and v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_supported_openai_params(self):
|
||||||
|
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
|
||||||
|
|
||||||
|
|
||||||
|
class PredibaseChatCompletion(BaseLLM):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict:
|
||||||
|
if api_key is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
|
||||||
|
)
|
||||||
|
headers = {
|
||||||
|
"content-type": "application/json",
|
||||||
|
"Authorization": "Bearer {}".format(api_key),
|
||||||
|
}
|
||||||
|
if user_headers is not None and isinstance(user_headers, dict):
|
||||||
|
headers = {**headers, **user_headers}
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def output_parser(self, generated_text: str):
|
||||||
|
"""
|
||||||
|
Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
|
||||||
|
|
||||||
|
Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
|
||||||
|
"""
|
||||||
|
chat_template_tokens = [
|
||||||
|
"<|assistant|>",
|
||||||
|
"<|system|>",
|
||||||
|
"<|user|>",
|
||||||
|
"<s>",
|
||||||
|
"</s>",
|
||||||
|
]
|
||||||
|
for token in chat_template_tokens:
|
||||||
|
if generated_text.strip().startswith(token):
|
||||||
|
generated_text = generated_text.replace(token, "", 1)
|
||||||
|
if generated_text.endswith(token):
|
||||||
|
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
|
||||||
|
return generated_text
|
||||||
|
|
||||||
|
def process_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
response: Union[requests.Response, httpx.Response],
|
||||||
|
model_response: ModelResponse,
|
||||||
|
stream: bool,
|
||||||
|
logging_obj: litellm.utils.Logging,
|
||||||
|
optional_params: dict,
|
||||||
|
api_key: str,
|
||||||
|
data: Union[dict, str],
|
||||||
|
messages: list,
|
||||||
|
print_verbose,
|
||||||
|
encoding,
|
||||||
|
) -> ModelResponse:
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=response.text,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
print_verbose(f"raw model_response: {response.text}")
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
try:
|
||||||
|
completion_response = response.json()
|
||||||
|
except:
|
||||||
|
raise PredibaseError(message=response.text, status_code=422)
|
||||||
|
if "error" in completion_response:
|
||||||
|
raise PredibaseError(
|
||||||
|
message=str(completion_response["error"]),
|
||||||
|
status_code=response.status_code,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if (
|
||||||
|
not isinstance(completion_response, dict)
|
||||||
|
or "generated_text" not in completion_response
|
||||||
|
):
|
||||||
|
raise PredibaseError(
|
||||||
|
status_code=422,
|
||||||
|
message=f"response is not in expected format - {completion_response}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(completion_response["generated_text"]) > 0:
|
||||||
|
model_response["choices"][0]["message"]["content"] = self.output_parser(
|
||||||
|
completion_response["generated_text"]
|
||||||
|
)
|
||||||
|
## GETTING LOGPROBS + FINISH REASON
|
||||||
|
if (
|
||||||
|
"details" in completion_response
|
||||||
|
and "tokens" in completion_response["details"]
|
||||||
|
):
|
||||||
|
model_response.choices[0].finish_reason = completion_response[
|
||||||
|
"details"
|
||||||
|
]["finish_reason"]
|
||||||
|
sum_logprob = 0
|
||||||
|
for token in completion_response["details"]["tokens"]:
|
||||||
|
if token["logprob"] != None:
|
||||||
|
sum_logprob += token["logprob"]
|
||||||
|
model_response["choices"][0][
|
||||||
|
"message"
|
||||||
|
]._logprob = (
|
||||||
|
sum_logprob # [TODO] move this to using the actual logprobs
|
||||||
|
)
|
||||||
|
if "best_of" in optional_params and optional_params["best_of"] > 1:
|
||||||
|
if (
|
||||||
|
"details" in completion_response
|
||||||
|
and "best_of_sequences" in completion_response["details"]
|
||||||
|
):
|
||||||
|
choices_list = []
|
||||||
|
for idx, item in enumerate(
|
||||||
|
completion_response["details"]["best_of_sequences"]
|
||||||
|
):
|
||||||
|
sum_logprob = 0
|
||||||
|
for token in item["tokens"]:
|
||||||
|
if token["logprob"] != None:
|
||||||
|
sum_logprob += token["logprob"]
|
||||||
|
if len(item["generated_text"]) > 0:
|
||||||
|
message_obj = Message(
|
||||||
|
content=self.output_parser(item["generated_text"]),
|
||||||
|
logprobs=sum_logprob,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
message_obj = Message(content=None)
|
||||||
|
choice_obj = Choices(
|
||||||
|
finish_reason=item["finish_reason"],
|
||||||
|
index=idx + 1,
|
||||||
|
message=message_obj,
|
||||||
|
)
|
||||||
|
choices_list.append(choice_obj)
|
||||||
|
model_response["choices"].extend(choices_list)
|
||||||
|
|
||||||
|
## CALCULATING USAGE
|
||||||
|
prompt_tokens = 0
|
||||||
|
try:
|
||||||
|
prompt_tokens = len(
|
||||||
|
encoding.encode(model_response["choices"][0]["message"]["content"])
|
||||||
|
) ##[TODO] use a model-specific tokenizer here
|
||||||
|
except:
|
||||||
|
# this should remain non blocking we should not block a response returning if calculating usage fails
|
||||||
|
pass
|
||||||
|
output_text = model_response["choices"][0]["message"].get("content", "")
|
||||||
|
if output_text is not None and len(output_text) > 0:
|
||||||
|
completion_tokens = 0
|
||||||
|
try:
|
||||||
|
completion_tokens = len(
|
||||||
|
encoding.encode(
|
||||||
|
model_response["choices"][0]["message"].get("content", "")
|
||||||
|
)
|
||||||
|
) ##[TODO] use a model-specific tokenizer
|
||||||
|
except:
|
||||||
|
# this should remain non blocking we should not block a response returning if calculating usage fails
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
completion_tokens = 0
|
||||||
|
|
||||||
|
total_tokens = prompt_tokens + completion_tokens
|
||||||
|
|
||||||
|
model_response["created"] = int(time.time())
|
||||||
|
model_response["model"] = model
|
||||||
|
usage = Usage(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
)
|
||||||
|
model_response.usage = usage # type: ignore
|
||||||
|
return model_response
|
||||||
|
|
||||||
|
def completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
api_base: str,
|
||||||
|
custom_prompt_dict: dict,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
encoding,
|
||||||
|
api_key: str,
|
||||||
|
logging_obj,
|
||||||
|
optional_params: dict,
|
||||||
|
tenant_id: str,
|
||||||
|
acompletion=None,
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
headers: dict = {},
|
||||||
|
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||||
|
headers = self._validate_environment(api_key, headers)
|
||||||
|
completion_url = ""
|
||||||
|
input_text = ""
|
||||||
|
base_url = "https://serving.app.predibase.com"
|
||||||
|
if "https" in model:
|
||||||
|
completion_url = model
|
||||||
|
elif api_base:
|
||||||
|
base_url = api_base
|
||||||
|
elif "PREDIBASE_API_BASE" in os.environ:
|
||||||
|
base_url = os.getenv("PREDIBASE_API_BASE", "")
|
||||||
|
|
||||||
|
completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"
|
||||||
|
|
||||||
|
if optional_params.get("stream", False) == True:
|
||||||
|
completion_url += "/generate_stream"
|
||||||
|
else:
|
||||||
|
completion_url += "/generate"
|
||||||
|
|
||||||
|
if model in custom_prompt_dict:
|
||||||
|
# check if the model has a registered custom prompt
|
||||||
|
model_prompt_details = custom_prompt_dict[model]
|
||||||
|
prompt = custom_prompt(
|
||||||
|
role_dict=model_prompt_details["roles"],
|
||||||
|
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||||
|
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = prompt_factory(model=model, messages=messages)
|
||||||
|
|
||||||
|
## Load Config
|
||||||
|
config = litellm.PredibaseConfig.get_config()
|
||||||
|
for k, v in config.items():
|
||||||
|
if (
|
||||||
|
k not in optional_params
|
||||||
|
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||||
|
optional_params[k] = v
|
||||||
|
|
||||||
|
stream = optional_params.pop("stream", False)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"inputs": prompt,
|
||||||
|
"parameters": optional_params,
|
||||||
|
}
|
||||||
|
input_text = prompt
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.pre_call(
|
||||||
|
input=input_text,
|
||||||
|
api_key=api_key,
|
||||||
|
additional_args={
|
||||||
|
"complete_input_dict": data,
|
||||||
|
"headers": headers,
|
||||||
|
"api_base": completion_url,
|
||||||
|
"acompletion": acompletion,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
## COMPLETION CALL
|
||||||
|
if acompletion == True:
|
||||||
|
### ASYNC STREAMING
|
||||||
|
if stream == True:
|
||||||
|
return self.async_streaming(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
data=data,
|
||||||
|
api_base=completion_url,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
api_key=api_key,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=headers,
|
||||||
|
) # type: ignore
|
||||||
|
else:
|
||||||
|
### ASYNC COMPLETION
|
||||||
|
return self.async_completion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
data=data,
|
||||||
|
api_base=completion_url,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
api_key=api_key,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
stream=False,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=headers,
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
### SYNC STREAMING
|
||||||
|
if stream == True:
|
||||||
|
response = requests.post(
|
||||||
|
completion_url,
|
||||||
|
headers=headers,
|
||||||
|
data=json.dumps(data),
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
_response = CustomStreamWrapper(
|
||||||
|
response.iter_lines(),
|
||||||
|
model,
|
||||||
|
custom_llm_provider="predibase",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
return _response
|
||||||
|
### SYNC COMPLETION
|
||||||
|
else:
|
||||||
|
response = requests.post(
|
||||||
|
url=completion_url,
|
||||||
|
headers=headers,
|
||||||
|
data=json.dumps(data),
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.process_response(
|
||||||
|
model=model,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
stream=optional_params.get("stream", False),
|
||||||
|
logging_obj=logging_obj, # type: ignore
|
||||||
|
optional_params=optional_params,
|
||||||
|
api_key=api_key,
|
||||||
|
data=data,
|
||||||
|
messages=messages,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
api_base: str,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
encoding,
|
||||||
|
api_key,
|
||||||
|
logging_obj,
|
||||||
|
stream,
|
||||||
|
data: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
headers={},
|
||||||
|
) -> ModelResponse:
|
||||||
|
self.async_handler = AsyncHTTPHandler(
|
||||||
|
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||||
|
)
|
||||||
|
response = await self.async_handler.post(
|
||||||
|
api_base, headers=headers, data=json.dumps(data)
|
||||||
|
)
|
||||||
|
return self.process_response(
|
||||||
|
model=model,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
stream=stream,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_key=api_key,
|
||||||
|
data=data,
|
||||||
|
messages=messages,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_streaming(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
api_base: str,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
encoding,
|
||||||
|
api_key,
|
||||||
|
logging_obj,
|
||||||
|
data: dict,
|
||||||
|
optional_params=None,
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
headers={},
|
||||||
|
) -> CustomStreamWrapper:
|
||||||
|
self.async_handler = AsyncHTTPHandler(
|
||||||
|
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||||
|
)
|
||||||
|
data["stream"] = True
|
||||||
|
response = await self.async_handler.post(
|
||||||
|
url=api_base,
|
||||||
|
headers=headers,
|
||||||
|
data=json.dumps(data),
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise PredibaseError(
|
||||||
|
status_code=response.status_code, message=response.text
|
||||||
|
)
|
||||||
|
|
||||||
|
completion_stream = response.aiter_lines()
|
||||||
|
|
||||||
|
streamwrapper = CustomStreamWrapper(
|
||||||
|
completion_stream=completion_stream,
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="predibase",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
return streamwrapper
|
||||||
|
|
||||||
|
def embedding(self, *args, **kwargs):
|
||||||
|
pass
|
|
@ -12,6 +12,16 @@ from typing import (
|
||||||
Sequence,
|
Sequence,
|
||||||
)
|
)
|
||||||
import litellm
|
import litellm
|
||||||
|
from litellm.types.completion import (
|
||||||
|
ChatCompletionUserMessageParam,
|
||||||
|
ChatCompletionSystemMessageParam,
|
||||||
|
ChatCompletionMessageParam,
|
||||||
|
ChatCompletionFunctionMessageParam,
|
||||||
|
ChatCompletionMessageToolCallParam,
|
||||||
|
ChatCompletionToolMessageParam,
|
||||||
|
)
|
||||||
|
from litellm.types.llms.anthropic import *
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
def default_pt(messages):
|
def default_pt(messages):
|
||||||
|
@ -22,6 +32,41 @@ def prompt_injection_detection_default_pt():
|
||||||
return """Detect if a prompt is safe to run. Return 'UNSAFE' if not."""
|
return """Detect if a prompt is safe to run. Return 'UNSAFE' if not."""
|
||||||
|
|
||||||
|
|
||||||
|
def map_system_message_pt(messages: list) -> list:
|
||||||
|
"""
|
||||||
|
Convert 'system' message to 'user' message if provider doesn't support 'system' role.
|
||||||
|
|
||||||
|
Enabled via `completion(...,supports_system_message=False)`
|
||||||
|
|
||||||
|
If next message is a user message or assistant message -> merge system prompt into it
|
||||||
|
|
||||||
|
if next message is system -> append a user message instead of the system message
|
||||||
|
"""
|
||||||
|
|
||||||
|
new_messages = []
|
||||||
|
for i, m in enumerate(messages):
|
||||||
|
if m["role"] == "system":
|
||||||
|
if i < len(messages) - 1: # Not the last message
|
||||||
|
next_m = messages[i + 1]
|
||||||
|
next_role = next_m["role"]
|
||||||
|
if (
|
||||||
|
next_role == "user" or next_role == "assistant"
|
||||||
|
): # Next message is a user or assistant message
|
||||||
|
# Merge system prompt into the next message
|
||||||
|
next_m["content"] = m["content"] + " " + next_m["content"]
|
||||||
|
elif next_role == "system": # Next message is a system message
|
||||||
|
# Append a user message instead of the system message
|
||||||
|
new_message = {"role": "user", "content": m["content"]}
|
||||||
|
new_messages.append(new_message)
|
||||||
|
else: # Last message
|
||||||
|
new_message = {"role": "user", "content": m["content"]}
|
||||||
|
new_messages.append(new_message)
|
||||||
|
else: # Not a system message
|
||||||
|
new_messages.append(m)
|
||||||
|
|
||||||
|
return new_messages
|
||||||
|
|
||||||
|
|
||||||
# alpaca prompt template - for models like mythomax, etc.
|
# alpaca prompt template - for models like mythomax, etc.
|
||||||
def alpaca_pt(messages):
|
def alpaca_pt(messages):
|
||||||
prompt = custom_prompt(
|
prompt = custom_prompt(
|
||||||
|
@ -805,6 +850,13 @@ def convert_to_anthropic_tool_result(message: dict) -> dict:
|
||||||
"name": "get_current_weather",
|
"name": "get_current_weather",
|
||||||
"content": "function result goes here",
|
"content": "function result goes here",
|
||||||
},
|
},
|
||||||
|
|
||||||
|
OpenAI message with a function call result looks like:
|
||||||
|
{
|
||||||
|
"role": "function",
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"content": "function result goes here",
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -821,6 +873,7 @@ def convert_to_anthropic_tool_result(message: dict) -> dict:
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
if message["role"] == "tool":
|
||||||
tool_call_id = message.get("tool_call_id")
|
tool_call_id = message.get("tool_call_id")
|
||||||
content = message.get("content")
|
content = message.get("content")
|
||||||
|
|
||||||
|
@ -831,8 +884,31 @@ def convert_to_anthropic_tool_result(message: dict) -> dict:
|
||||||
"tool_use_id": tool_call_id,
|
"tool_use_id": tool_call_id,
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
|
|
||||||
return anthropic_tool_result
|
return anthropic_tool_result
|
||||||
|
elif message["role"] == "function":
|
||||||
|
content = message.get("content")
|
||||||
|
anthropic_tool_result = {
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": str(uuid.uuid4()),
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
return anthropic_tool_result
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def convert_function_to_anthropic_tool_invoke(function_call):
|
||||||
|
try:
|
||||||
|
anthropic_tool_invoke = [
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
"name": get_attribute_or_key(function_call, "name"),
|
||||||
|
"input": json.loads(get_attribute_or_key(function_call, "arguments")),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
return anthropic_tool_invoke
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def convert_to_anthropic_tool_invoke(tool_calls: list) -> list:
|
def convert_to_anthropic_tool_invoke(tool_calls: list) -> list:
|
||||||
|
@ -895,7 +971,7 @@ def convert_to_anthropic_tool_invoke(tool_calls: list) -> list:
|
||||||
def anthropic_messages_pt(messages: list):
|
def anthropic_messages_pt(messages: list):
|
||||||
"""
|
"""
|
||||||
format messages for anthropic
|
format messages for anthropic
|
||||||
1. Anthropic supports roles like "user" and "assistant", (here litellm translates system-> assistant)
|
1. Anthropic supports roles like "user" and "assistant" (system prompt sent separately)
|
||||||
2. The first message always needs to be of role "user"
|
2. The first message always needs to be of role "user"
|
||||||
3. Each message must alternate between "user" and "assistant" (this is not addressed as now by litellm)
|
3. Each message must alternate between "user" and "assistant" (this is not addressed as now by litellm)
|
||||||
4. final assistant content cannot end with trailing whitespace (anthropic raises an error otherwise)
|
4. final assistant content cannot end with trailing whitespace (anthropic raises an error otherwise)
|
||||||
|
@ -903,12 +979,14 @@ def anthropic_messages_pt(messages: list):
|
||||||
6. Ensure we only accept role, content. (message.name is not supported)
|
6. Ensure we only accept role, content. (message.name is not supported)
|
||||||
"""
|
"""
|
||||||
# add role=tool support to allow function call result/error submission
|
# add role=tool support to allow function call result/error submission
|
||||||
user_message_types = {"user", "tool"}
|
user_message_types = {"user", "tool", "function"}
|
||||||
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, merge them.
|
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, merge them.
|
||||||
new_messages = []
|
new_messages: list = []
|
||||||
msg_i = 0
|
msg_i = 0
|
||||||
|
tool_use_param = False
|
||||||
while msg_i < len(messages):
|
while msg_i < len(messages):
|
||||||
user_content = []
|
user_content = []
|
||||||
|
init_msg_i = msg_i
|
||||||
## MERGE CONSECUTIVE USER CONTENT ##
|
## MERGE CONSECUTIVE USER CONTENT ##
|
||||||
while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
|
while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
|
||||||
if isinstance(messages[msg_i]["content"], list):
|
if isinstance(messages[msg_i]["content"], list):
|
||||||
|
@ -924,7 +1002,10 @@ def anthropic_messages_pt(messages: list):
|
||||||
)
|
)
|
||||||
elif m.get("type", "") == "text":
|
elif m.get("type", "") == "text":
|
||||||
user_content.append({"type": "text", "text": m["text"]})
|
user_content.append({"type": "text", "text": m["text"]})
|
||||||
elif messages[msg_i]["role"] == "tool":
|
elif (
|
||||||
|
messages[msg_i]["role"] == "tool"
|
||||||
|
or messages[msg_i]["role"] == "function"
|
||||||
|
):
|
||||||
# OpenAI's tool message content will always be a string
|
# OpenAI's tool message content will always be a string
|
||||||
user_content.append(convert_to_anthropic_tool_result(messages[msg_i]))
|
user_content.append(convert_to_anthropic_tool_result(messages[msg_i]))
|
||||||
else:
|
else:
|
||||||
|
@ -953,11 +1034,24 @@ def anthropic_messages_pt(messages: list):
|
||||||
convert_to_anthropic_tool_invoke(messages[msg_i]["tool_calls"])
|
convert_to_anthropic_tool_invoke(messages[msg_i]["tool_calls"])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if messages[msg_i].get("function_call"):
|
||||||
|
assistant_content.extend(
|
||||||
|
convert_function_to_anthropic_tool_invoke(
|
||||||
|
messages[msg_i]["function_call"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
msg_i += 1
|
msg_i += 1
|
||||||
|
|
||||||
if assistant_content:
|
if assistant_content:
|
||||||
new_messages.append({"role": "assistant", "content": assistant_content})
|
new_messages.append({"role": "assistant", "content": assistant_content})
|
||||||
|
|
||||||
|
if msg_i == init_msg_i: # prevent infinite loops
|
||||||
|
raise Exception(
|
||||||
|
"Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
|
||||||
|
messages[msg_i]
|
||||||
|
)
|
||||||
|
)
|
||||||
if not new_messages or new_messages[0]["role"] != "user":
|
if not new_messages or new_messages[0]["role"] != "user":
|
||||||
if litellm.modify_params:
|
if litellm.modify_params:
|
||||||
new_messages.insert(
|
new_messages.insert(
|
||||||
|
@ -969,6 +1063,9 @@ def anthropic_messages_pt(messages: list):
|
||||||
)
|
)
|
||||||
|
|
||||||
if new_messages[-1]["role"] == "assistant":
|
if new_messages[-1]["role"] == "assistant":
|
||||||
|
if isinstance(new_messages[-1]["content"], str):
|
||||||
|
new_messages[-1]["content"] = new_messages[-1]["content"].rstrip()
|
||||||
|
elif isinstance(new_messages[-1]["content"], list):
|
||||||
for content in new_messages[-1]["content"]:
|
for content in new_messages[-1]["content"]:
|
||||||
if isinstance(content, dict) and content["type"] == "text":
|
if isinstance(content, dict) and content["type"] == "text":
|
||||||
content["text"] = content[
|
content["text"] = content[
|
||||||
|
@ -1412,6 +1509,11 @@ def prompt_factory(
|
||||||
model="meta-llama/Meta-Llama-3-8B-Instruct",
|
model="meta-llama/Meta-Llama-3-8B-Instruct",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif custom_llm_provider == "clarifai":
|
||||||
|
if "claude" in model:
|
||||||
|
return anthropic_pt(messages=messages)
|
||||||
|
|
||||||
elif custom_llm_provider == "perplexity":
|
elif custom_llm_provider == "perplexity":
|
||||||
for message in messages:
|
for message in messages:
|
||||||
message.pop("name", None)
|
message.pop("name", None)
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
import os, types
|
import os, types
|
||||||
import json
|
import json
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
from litellm.utils import ModelResponse, Usage
|
from litellm.utils import ModelResponse, Usage
|
||||||
import litellm
|
import litellm
|
||||||
import httpx
|
import httpx # type: ignore
|
||||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,14 @@
|
||||||
import os, types, traceback
|
import os, types, traceback
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import json
|
import json
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Optional, Any
|
from typing import Callable, Optional, Any
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage
|
from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage
|
||||||
import sys
|
import sys
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import httpx
|
import httpx # type: ignore
|
||||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
|
|
||||||
|
|
||||||
|
@ -295,7 +295,7 @@ def completion(
|
||||||
EndpointName={model},
|
EndpointName={model},
|
||||||
InferenceComponentName={model_id},
|
InferenceComponentName={model_id},
|
||||||
ContentType="application/json",
|
ContentType="application/json",
|
||||||
Body={data},
|
Body={data}, # type: ignore
|
||||||
CustomAttributes="accept_eula=true",
|
CustomAttributes="accept_eula=true",
|
||||||
)
|
)
|
||||||
""" # type: ignore
|
""" # type: ignore
|
||||||
|
@ -321,7 +321,7 @@ def completion(
|
||||||
response = client.invoke_endpoint(
|
response = client.invoke_endpoint(
|
||||||
EndpointName={model},
|
EndpointName={model},
|
||||||
ContentType="application/json",
|
ContentType="application/json",
|
||||||
Body={data},
|
Body={data}, # type: ignore
|
||||||
CustomAttributes="accept_eula=true",
|
CustomAttributes="accept_eula=true",
|
||||||
)
|
)
|
||||||
""" # type: ignore
|
""" # type: ignore
|
||||||
|
@ -688,7 +688,7 @@ def embedding(
|
||||||
response = client.invoke_endpoint(
|
response = client.invoke_endpoint(
|
||||||
EndpointName={model},
|
EndpointName={model},
|
||||||
ContentType="application/json",
|
ContentType="application/json",
|
||||||
Body={data},
|
Body={data}, # type: ignore
|
||||||
CustomAttributes="accept_eula=true",
|
CustomAttributes="accept_eula=true",
|
||||||
)""" # type: ignore
|
)""" # type: ignore
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
|
|
|
@ -6,11 +6,11 @@ Reference: https://docs.together.ai/docs/openai-api-compatibility
|
||||||
import os, types
|
import os, types
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
import litellm
|
import litellm
|
||||||
import httpx
|
import httpx # type: ignore
|
||||||
from litellm.utils import ModelResponse, Usage
|
from litellm.utils import ModelResponse, Usage
|
||||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
|
|
||||||
|
|
119
litellm/llms/triton.py
Normal file
119
litellm/llms/triton.py
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
import os, types
|
||||||
|
import json
|
||||||
|
from enum import Enum
|
||||||
|
import requests, copy # type: ignore
|
||||||
|
import time
|
||||||
|
from typing import Callable, Optional, List
|
||||||
|
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
|
||||||
|
import litellm
|
||||||
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
|
from .base import BaseLLM
|
||||||
|
import httpx # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class TritonError(Exception):
|
||||||
|
def __init__(self, status_code, message):
|
||||||
|
self.status_code = status_code
|
||||||
|
self.message = message
|
||||||
|
self.request = httpx.Request(
|
||||||
|
method="POST",
|
||||||
|
url="https://api.anthropic.com/v1/messages", # using anthropic api base since httpx requires a url
|
||||||
|
)
|
||||||
|
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||||
|
super().__init__(
|
||||||
|
self.message
|
||||||
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
|
class TritonChatCompletion(BaseLLM):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
async def aembedding(
|
||||||
|
self,
|
||||||
|
data: dict,
|
||||||
|
model_response: litellm.utils.EmbeddingResponse,
|
||||||
|
api_base: str,
|
||||||
|
logging_obj=None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
async_handler = AsyncHTTPHandler(
|
||||||
|
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await async_handler.post(url=api_base, data=json.dumps(data))
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise TritonError(status_code=response.status_code, message=response.text)
|
||||||
|
|
||||||
|
_text_response = response.text
|
||||||
|
|
||||||
|
logging_obj.post_call(original_response=_text_response)
|
||||||
|
|
||||||
|
_json_response = response.json()
|
||||||
|
|
||||||
|
_outputs = _json_response["outputs"]
|
||||||
|
_output_data = _outputs[0]["data"]
|
||||||
|
_embedding_output = {
|
||||||
|
"object": "embedding",
|
||||||
|
"index": 0,
|
||||||
|
"embedding": _output_data,
|
||||||
|
}
|
||||||
|
|
||||||
|
model_response.model = _json_response.get("model_name", "None")
|
||||||
|
model_response.data = [_embedding_output]
|
||||||
|
|
||||||
|
return model_response
|
||||||
|
|
||||||
|
def embedding(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
input: list,
|
||||||
|
timeout: float,
|
||||||
|
api_base: str,
|
||||||
|
model_response: litellm.utils.EmbeddingResponse,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
logging_obj=None,
|
||||||
|
optional_params=None,
|
||||||
|
client=None,
|
||||||
|
aembedding=None,
|
||||||
|
):
|
||||||
|
data_for_triton = {
|
||||||
|
"inputs": [
|
||||||
|
{
|
||||||
|
"name": "input_text",
|
||||||
|
"shape": [1],
|
||||||
|
"datatype": "BYTES",
|
||||||
|
"data": input,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
|
||||||
|
curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"
|
||||||
|
|
||||||
|
logging_obj.pre_call(
|
||||||
|
input="",
|
||||||
|
api_key=None,
|
||||||
|
additional_args={
|
||||||
|
"complete_input_dict": optional_params,
|
||||||
|
"request_str": curl_string,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if aembedding == True:
|
||||||
|
response = self.aembedding(
|
||||||
|
data=data_for_triton,
|
||||||
|
model_response=model_response,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
"Only async embedding supported for triton, please use litellm.aembedding() for now"
|
||||||
|
)
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue