Merge branch 'main' into feat/add-azure-content-filter

This commit is contained in:
Krish Dholakia 2024-05-11 09:30:38 -07:00 committed by GitHub
commit bbe1300c5b
200 changed files with 19218 additions and 2966 deletions


@ -1,4 +1,4 @@
version: 2.1 version: 4.3.4
jobs: jobs:
local_testing: local_testing:
docker: docker:
@ -188,7 +188,7 @@ jobs:
command: | command: |
docker run -d \ docker run -d \
-p 4000:4000 \ -p 4000:4000 \
-e DATABASE_URL=$PROXY_DOCKER_DB_URL \ -e DATABASE_URL=$PROXY_DATABASE_URL \
-e AZURE_API_KEY=$AZURE_API_KEY \ -e AZURE_API_KEY=$AZURE_API_KEY \
-e REDIS_HOST=$REDIS_HOST \ -e REDIS_HOST=$REDIS_HOST \
-e REDIS_PASSWORD=$REDIS_PASSWORD \ -e REDIS_PASSWORD=$REDIS_PASSWORD \
@ -198,6 +198,7 @@ jobs:
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \ -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
-e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \ -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
-e AWS_REGION_NAME=$AWS_REGION_NAME \ -e AWS_REGION_NAME=$AWS_REGION_NAME \
-e AUTO_INFER_REGION=True \
-e OPENAI_API_KEY=$OPENAI_API_KEY \ -e OPENAI_API_KEY=$OPENAI_API_KEY \
-e LANGFUSE_PROJECT1_PUBLIC=$LANGFUSE_PROJECT1_PUBLIC \ -e LANGFUSE_PROJECT1_PUBLIC=$LANGFUSE_PROJECT1_PUBLIC \
-e LANGFUSE_PROJECT2_PUBLIC=$LANGFUSE_PROJECT2_PUBLIC \ -e LANGFUSE_PROJECT2_PUBLIC=$LANGFUSE_PROJECT2_PUBLIC \
@ -208,9 +209,7 @@ jobs:
my-app:latest \ my-app:latest \
--config /app/config.yaml \ --config /app/config.yaml \
--port 4000 \ --port 4000 \
--num_workers 8 \
--detailed_debug \ --detailed_debug \
--run_gunicorn \
- run: - run:
name: Install curl and dockerize name: Install curl and dockerize
command: | command: |
@ -225,7 +224,7 @@ jobs:
background: true background: true
- run: - run:
name: Wait for app to be ready name: Wait for app to be ready
command: dockerize -wait http://localhost:4000 -timeout 1m command: dockerize -wait http://localhost:4000 -timeout 5m
- run: - run:
name: Run tests name: Run tests
command: | command: |


@ -0,0 +1,51 @@
{
"name": "Python 3.11",
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
"image": "mcr.microsoft.com/devcontainers/python:3.11-bookworm",
// https://github.com/devcontainers/images/tree/main/src/python
// https://mcr.microsoft.com/en-us/product/devcontainers/python/tags
// "build": {
// "dockerfile": "Dockerfile",
// "context": ".."
// },
// Features to add to the dev container. More info: https://containers.dev/features.
// "features": {},
// Configure tool-specific properties.
"customizations": {
// Configure properties specific to VS Code.
"vscode": {
"settings": {},
"extensions": [
"ms-python.python",
"ms-python.vscode-pylance",
"GitHub.copilot",
"GitHub.copilot-chat"
]
}
},
// Use 'forwardPorts' to make a list of ports inside the container available locally.
"forwardPorts": [4000],
"containerEnv": {
"LITELLM_LOG": "DEBUG"
},
// Use 'portsAttributes' to set default properties for specific forwarded ports.
// More info: https://containers.dev/implementors/json_reference/#port-attributes
"portsAttributes": {
"4000": {
"label": "LiteLLM Server",
"onAutoForward": "notify"
}
},
// More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "litellm",
// Use 'postCreateCommand' to run commands after the container is created.
"postCreateCommand": "pipx install poetry && poetry install -E extra_proxy -E proxy"
}


@ -64,6 +64,11 @@ if __name__ == "__main__":
) # Replace with your repository's username and name ) # Replace with your repository's username and name
latest_release = repo.get_latest_release() latest_release = repo.get_latest_release()
print("got latest release: ", latest_release) print("got latest release: ", latest_release)
print(latest_release.title)
print(latest_release.tag_name)
release_version = latest_release.title
print("latest release body: ", latest_release.body) print("latest release body: ", latest_release.body)
print("markdown table: ", markdown_table) print("markdown table: ", markdown_table)
@ -74,8 +79,22 @@ if __name__ == "__main__":
start_index = latest_release.body.find("Load Test LiteLLM Proxy Results") start_index = latest_release.body.find("Load Test LiteLLM Proxy Results")
existing_release_body = latest_release.body[:start_index] existing_release_body = latest_release.body[:start_index]
docker_run_command = f"""
\n\n
## Docker Run LiteLLM Proxy
```
docker run \\
-e STORE_MODEL_IN_DB=True \\
-p 4000:4000 \\
ghcr.io/berriai/litellm:main-{release_version}
```
"""
print("docker run command: ", docker_run_command)
new_release_body = ( new_release_body = (
existing_release_body existing_release_body
+ docker_run_command
+ "\n\n" + "\n\n"
+ "### Don't want to maintain your internal proxy? get in touch 🎉" + "### Don't want to maintain your internal proxy? get in touch 🎉"
+ "\nHosted Proxy Alpha: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat" + "\nHosted Proxy Alpha: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat"

.gitignore

@ -1,5 +1,6 @@
.venv .venv
.env .env
litellm/proxy/myenv/*
litellm_uuid.txt litellm_uuid.txt
__pycache__/ __pycache__/
*.pyc *.pyc
@ -52,3 +53,6 @@ litellm/proxy/_new_secret_config.yaml
litellm/proxy/_new_secret_config.yaml litellm/proxy/_new_secret_config.yaml
litellm/proxy/_super_secret_config.yaml litellm/proxy/_super_secret_config.yaml
litellm/proxy/_super_secret_config.yaml litellm/proxy/_super_secret_config.yaml
litellm/proxy/myenv/bin/activate
litellm/proxy/myenv/bin/Activate.ps1
myenv/*


@ -226,6 +226,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
| [deepinfra](https://docs.litellm.ai/docs/providers/deepinfra) | ✅ | ✅ | ✅ | ✅ | | [deepinfra](https://docs.litellm.ai/docs/providers/deepinfra) | ✅ | ✅ | ✅ | ✅ |
| [perplexity-ai](https://docs.litellm.ai/docs/providers/perplexity) | ✅ | ✅ | ✅ | ✅ | | [perplexity-ai](https://docs.litellm.ai/docs/providers/perplexity) | ✅ | ✅ | ✅ | ✅ |
| [Groq AI](https://docs.litellm.ai/docs/providers/groq) | ✅ | ✅ | ✅ | ✅ | | [Groq AI](https://docs.litellm.ai/docs/providers/groq) | ✅ | ✅ | ✅ | ✅ |
| [Deepseek](https://docs.litellm.ai/docs/providers/deepseek) | ✅ | ✅ | ✅ | ✅ |
| [anyscale](https://docs.litellm.ai/docs/providers/anyscale) | ✅ | ✅ | ✅ | ✅ | | [anyscale](https://docs.litellm.ai/docs/providers/anyscale) | ✅ | ✅ | ✅ | ✅ |
| [IBM - watsonx.ai](https://docs.litellm.ai/docs/providers/watsonx) | ✅ | ✅ | ✅ | ✅ | ✅ | [IBM - watsonx.ai](https://docs.litellm.ai/docs/providers/watsonx) | ✅ | ✅ | ✅ | ✅ | ✅
| [voyage ai](https://docs.litellm.ai/docs/providers/voyage) | | | | | ✅ | | [voyage ai](https://docs.litellm.ai/docs/providers/voyage) | | | | | ✅ |


@ -0,0 +1,15 @@
{
"$schema": "https://schema.management.azure.com/schemas/0.1.2-preview/CreateUIDefinition.MultiVm.json#",
"handler": "Microsoft.Azure.CreateUIDef",
"version": "0.1.2-preview",
"parameters": {
"config": {
"isWizard": false,
"basics": { }
},
"basics": [ ],
"steps": [ ],
"outputs": { },
"resourceTypes": [ ]
}
}


@ -0,0 +1,63 @@
{
"$schema": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#",
"contentVersion": "1.0.0.0",
"parameters": {
"imageName": {
"type": "string",
"defaultValue": "ghcr.io/berriai/litellm:main-latest"
},
"containerName": {
"type": "string",
"defaultValue": "litellm-container"
},
"dnsLabelName": {
"type": "string",
"defaultValue": "litellm"
},
"portNumber": {
"type": "int",
"defaultValue": 4000
}
},
"resources": [
{
"type": "Microsoft.ContainerInstance/containerGroups",
"apiVersion": "2021-03-01",
"name": "[parameters('containerName')]",
"location": "[resourceGroup().location]",
"properties": {
"containers": [
{
"name": "[parameters('containerName')]",
"properties": {
"image": "[parameters('imageName')]",
"resources": {
"requests": {
"cpu": 1,
"memoryInGB": 2
}
},
"ports": [
{
"port": "[parameters('portNumber')]"
}
]
}
}
],
"osType": "Linux",
"restartPolicy": "Always",
"ipAddress": {
"type": "Public",
"ports": [
{
"protocol": "tcp",
"port": "[parameters('portNumber')]"
}
],
"dnsNameLabel": "[parameters('dnsLabelName')]"
}
}
}
]
}


@ -0,0 +1,42 @@
param imageName string = 'ghcr.io/berriai/litellm:main-latest'
param containerName string = 'litellm-container'
param dnsLabelName string = 'litellm'
param portNumber int = 4000
resource containerGroupName 'Microsoft.ContainerInstance/containerGroups@2021-03-01' = {
name: containerName
location: resourceGroup().location
properties: {
containers: [
{
name: containerName
properties: {
image: imageName
resources: {
requests: {
cpu: 1
memoryInGB: 2
}
}
ports: [
{
port: portNumber
}
]
}
}
]
osType: 'Linux'
restartPolicy: 'Always'
ipAddress: {
type: 'Public'
ports: [
{
protocol: 'tcp'
port: portNumber
}
]
dnsNameLabel: dnsLabelName
}
}
}


@ -24,7 +24,7 @@ version: 0.2.0
# incremented each time you make changes to the application. Versions are not expected to # incremented each time you make changes to the application. Versions are not expected to
# follow Semantic Versioning. They should reflect the version the application is using. # follow Semantic Versioning. They should reflect the version the application is using.
# It is recommended to use it with quotes. # It is recommended to use it with quotes.
appVersion: v1.24.5 appVersion: v1.35.38
dependencies: dependencies:
- name: "postgresql" - name: "postgresql"


@ -83,8 +83,9 @@ def completion(
top_p: Optional[float] = None, top_p: Optional[float] = None,
n: Optional[int] = None, n: Optional[int] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
stream_options: Optional[dict] = None,
stop=None, stop=None,
max_tokens: Optional[float] = None, max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None, presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None,
logit_bias: Optional[dict] = None, logit_bias: Optional[dict] = None,
@ -139,6 +140,10 @@ def completion(
- `stream`: *boolean or null (optional)* - If set to true, it sends partial message deltas. Tokens will be sent as they become available, with the stream terminated by a [DONE] message. - `stream`: *boolean or null (optional)* - If set to true, it sends partial message deltas. Tokens will be sent as they become available, with the stream terminated by a [DONE] message.
- `stream_options`: *dict or null (optional)* - Options for the streaming response. Only set this when you set `stream: true` (see the sketch after this list).
  - `include_usage`: *boolean (optional)* - If set, an additional chunk will be streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and the `choices` field will always be an empty array. All other chunks will also include a `usage` field, but with a null value.
- `stop`: *string/ array/ null (optional)* - Up to 4 sequences where the API will stop generating further tokens. - `stop`: *string/ array/ null (optional)* - Up to 4 sequences where the API will stop generating further tokens.
- `max_tokens`: *integer (optional)* - The maximum number of tokens to generate in the chat completion. - `max_tokens`: *integer (optional)* - The maximum number of tokens to generate in the chat completion.
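As an illustrative aside (not part of the diff): a minimal sketch of using the new `stream_options` parameter with the SDK, assuming `OPENAI_API_KEY` is set and using `gpt-3.5-turbo` purely as an example model.

```python
from litellm import completion

# stream the response and request a final usage chunk
response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
    stream_options={"include_usage": True},
)

for chunk in response:
    print(chunk)  # the last chunk before [DONE] carries the usage statistics
```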


@ -1,7 +1,7 @@
# Completion Token Usage & Cost # Completion Token Usage & Cost
By default LiteLLM returns token usage in all completion requests ([See here](https://litellm.readthedocs.io/en/latest/output/)) By default LiteLLM returns token usage in all completion requests ([See here](https://litellm.readthedocs.io/en/latest/output/))
However, we also expose 5 helper functions + **[NEW]** an API to calculate token usage across providers: However, we also expose some helper functions + **[NEW]** an API to calculate token usage across providers:
- `encode`: This encodes the text passed in, using the model-specific tokenizer. [**Jump to code**](#1-encode) - `encode`: This encodes the text passed in, using the model-specific tokenizer. [**Jump to code**](#1-encode)
@ -9,17 +9,19 @@ However, we also expose 5 helper functions + **[NEW]** an API to calculate token
- `token_counter`: This returns the number of tokens for a given input - it uses the tokenizer based on the model, and defaults to tiktoken if no model-specific tokenizer is available. [**Jump to code**](#3-token_counter) - `token_counter`: This returns the number of tokens for a given input - it uses the tokenizer based on the model, and defaults to tiktoken if no model-specific tokenizer is available. [**Jump to code**](#3-token_counter)
- `cost_per_token`: This returns the cost (in USD) for prompt (input) and completion (output) tokens. Uses the live list from `api.litellm.ai`. [**Jump to code**](#4-cost_per_token) - `create_pretrained_tokenizer` and `create_tokenizer`: LiteLLM provides default tokenizer support for OpenAI, Cohere, Anthropic, Llama2, and Llama3 models. If you are using a different model, you can create a custom tokenizer and pass it as `custom_tokenizer` to the `encode`, `decode`, and `token_counter` methods. [**Jump to code**](#4-create_pretrained_tokenizer-and-create_tokenizer)
- `completion_cost`: This returns the overall cost (in USD) for a given LLM API Call. It combines `token_counter` and `cost_per_token` to return the cost for that query (counting both cost of input and output). [**Jump to code**](#5-completion_cost) - `cost_per_token`: This returns the cost (in USD) for prompt (input) and completion (output) tokens. Uses the live list from `api.litellm.ai`. [**Jump to code**](#5-cost_per_token)
- `get_max_tokens`: This returns the maximum number of tokens allowed for the given model. [**Jump to code**](#6-get_max_tokens) - `completion_cost`: This returns the overall cost (in USD) for a given LLM API Call. It combines `token_counter` and `cost_per_token` to return the cost for that query (counting both cost of input and output). [**Jump to code**](#6-completion_cost)
- `model_cost`: This returns a dictionary for all models, with their max_tokens, input_cost_per_token and output_cost_per_token. It uses the `api.litellm.ai` call shown below. [**Jump to code**](#7-model_cost) - `get_max_tokens`: This returns the maximum number of tokens allowed for the given model. [**Jump to code**](#7-get_max_tokens)
- `register_model`: This registers new / overrides existing models (and their pricing details) in the model cost dictionary. [**Jump to code**](#8-register_model) - `model_cost`: This returns a dictionary for all models, with their max_tokens, input_cost_per_token and output_cost_per_token. It uses the `api.litellm.ai` call shown below. [**Jump to code**](#8-model_cost)
- `api.litellm.ai`: Live token + price count across [all supported models](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). [**Jump to code**](#9-apilitellmai) - `register_model`: This registers new / overrides existing models (and their pricing details) in the model cost dictionary. [**Jump to code**](#9-register_model)
- `api.litellm.ai`: Live token + price count across [all supported models](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). [**Jump to code**](#10-apilitellmai)
📣 This is a community maintained list. Contributions are welcome! ❤️ 📣 This is a community maintained list. Contributions are welcome! ❤️
@ -60,7 +62,24 @@ messages = [{"user": "role", "content": "Hey, how's it going"}]
print(token_counter(model="gpt-3.5-turbo", messages=messages)) print(token_counter(model="gpt-3.5-turbo", messages=messages))
``` ```
### 4. `cost_per_token` ### 4. `create_pretrained_tokenizer` and `create_tokenizer`
```python
import json
from litellm import create_pretrained_tokenizer, create_tokenizer

# get tokenizer from huggingface repo
custom_tokenizer_1 = create_pretrained_tokenizer("Xenova/llama-3-tokenizer")

# use tokenizer from json file
with open("tokenizer.json") as f:
    json_data = json.load(f)

json_str = json.dumps(json_data)
custom_tokenizer_2 = create_tokenizer(json_str)
```
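A hedged follow-up sketch (not in the original docs): passing the custom tokenizer created above to `token_counter`, as the helper list describes; the message content is illustrative.

```python
from litellm import token_counter

messages = [{"role": "user", "content": "Hey, how's it going?"}]

# count tokens with the custom tokenizer created above
print(token_counter(custom_tokenizer=custom_tokenizer_1, messages=messages))
```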
### 5. `cost_per_token`
```python ```python
from litellm import cost_per_token from litellm import cost_per_token
@ -72,7 +91,7 @@ prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_toke
print(prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar) print(prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar)
``` ```
### 5. `completion_cost` ### 6. `completion_cost`
* Input: Accepts a `litellm.completion()` response **OR** prompt + completion strings * Input: Accepts a `litellm.completion()` response **OR** prompt + completion strings
* Output: Returns a `float` of cost for the `completion` call * Output: Returns a `float` of cost for the `completion` call
@ -99,7 +118,7 @@ cost = completion_cost(model="bedrock/anthropic.claude-v2", prompt="Hey!", compl
formatted_string = f"${float(cost):.10f}" formatted_string = f"${float(cost):.10f}"
print(formatted_string) print(formatted_string)
``` ```
### 6. `get_max_tokens` ### 7. `get_max_tokens`
Input: Accepts a model name - e.g., gpt-3.5-turbo (to get a complete list, call litellm.model_list). Input: Accepts a model name - e.g., gpt-3.5-turbo (to get a complete list, call litellm.model_list).
Output: Returns the maximum number of tokens allowed for the given model Output: Returns the maximum number of tokens allowed for the given model
@ -112,7 +131,7 @@ model = "gpt-3.5-turbo"
print(get_max_tokens(model)) # Output: 4097 print(get_max_tokens(model)) # Output: 4097
``` ```
### 7. `model_cost` ### 8. `model_cost`
* Output: Returns a dict object containing the max_tokens, input_cost_per_token, output_cost_per_token for all models on [community-maintained list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json) * Output: Returns a dict object containing the max_tokens, input_cost_per_token, output_cost_per_token for all models on [community-maintained list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
@ -122,7 +141,7 @@ from litellm import model_cost
print(model_cost) # {'gpt-3.5-turbo': {'max_tokens': 4000, 'input_cost_per_token': 1.5e-06, 'output_cost_per_token': 2e-06}, ...} print(model_cost) # {'gpt-3.5-turbo': {'max_tokens': 4000, 'input_cost_per_token': 1.5e-06, 'output_cost_per_token': 2e-06}, ...}
``` ```
### 8. `register_model` ### 9. `register_model`
* Input: Provide EITHER a model cost dictionary or a url to a hosted json blob * Input: Provide EITHER a model cost dictionary or a url to a hosted json blob
* Output: Returns updated model_cost dictionary + updates litellm.model_cost with model details. * Output: Returns updated model_cost dictionary + updates litellm.model_cost with model details.
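A minimal sketch (not from the diff) of registering a custom entry with a cost dictionary; the model name and prices are purely illustrative.

```python
import litellm

# register (or override) a model's pricing in the cost map
litellm.register_model({
    "my-custom-model": {                    # hypothetical model name
        "max_tokens": 8192,
        "input_cost_per_token": 0.0000015,  # illustrative prices
        "output_cost_per_token": 0.000002,
        "litellm_provider": "openai",
        "mode": "chat",
    }
})

print(litellm.model_cost["my-custom-model"])
```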
@ -157,5 +176,3 @@ export LITELLM_LOCAL_MODEL_COST_MAP="True"
``` ```
Note: this means you will need to upgrade to get updated pricing, and newer models. Note: this means you will need to upgrade to get updated pricing, and newer models.


@ -320,8 +320,6 @@ from litellm import embedding
litellm.vertex_project = "hardy-device-38811" # Your Project ID litellm.vertex_project = "hardy-device-38811" # Your Project ID
litellm.vertex_location = "us-central1" # proj location litellm.vertex_location = "us-central1" # proj location
os.environ['VOYAGE_API_KEY'] = ""
response = embedding( response = embedding(
model="vertex_ai/textembedding-gecko", model="vertex_ai/textembedding-gecko",
input=["good morning from litellm"], input=["good morning from litellm"],


@ -17,6 +17,14 @@ This covers:
- ✅ [**JWT-Auth**](../docs/proxy/token_auth.md) - ✅ [**JWT-Auth**](../docs/proxy/token_auth.md)
## [COMING SOON] AWS Marketplace Support
Deploy managed LiteLLM Proxy within your VPC.
Includes all enterprise features.
[**Get early access**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
## Frequently Asked Questions ## Frequently Asked Questions
### What topics does Professional support cover and what SLAs do you offer? ### What topics does Professional support cover and what SLAs do you offer?


@ -13,7 +13,7 @@ LiteLLM maps exceptions across all providers to their OpenAI counterparts.
| >=500 | InternalServerError | | >=500 | InternalServerError |
| N/A | ContextWindowExceededError| | N/A | ContextWindowExceededError|
| 400 | ContentPolicyViolationError| | 400 | ContentPolicyViolationError|
| N/A | APIConnectionError | | 500 | APIConnectionError |
Base case we return APIConnectionError Base case we return APIConnectionError
@ -74,6 +74,28 @@ except Exception as e:
``` ```
## Usage - Should you retry exception?
```python
import litellm
import openai

try:
    response = litellm.completion(
        model="gpt-4",
        messages=[
            {
                "role": "user",
                "content": "hello, write a 20 page essay"
            }
        ],
        timeout=0.01, # this will raise a timeout exception
    )
except openai.APITimeoutError as e:
    should_retry = litellm._should_retry(e.status_code)
    print(f"should_retry: {should_retry}")
```
## Details ## Details
To see how it's implemented - [check out the code](https://github.com/BerriAI/litellm/blob/a42c197e5a6de56ea576c73715e6c7c6b19fa249/litellm/utils.py#L1217) To see how it's implemented - [check out the code](https://github.com/BerriAI/litellm/blob/a42c197e5a6de56ea576c73715e6c7c6b19fa249/litellm/utils.py#L1217)
@ -86,21 +108,34 @@ To see how it's implemented - [check out the code](https://github.com/BerriAI/li
Base case - we return the original exception. Base case - we return the original exception.
| | ContextWindowExceededError | AuthenticationError | InvalidRequestError | RateLimitError | ServiceUnavailableError | | custom_llm_provider | Timeout | ContextWindowExceededError | BadRequestError | NotFoundError | ContentPolicyViolationError | AuthenticationError | APIError | RateLimitError | ServiceUnavailableError | PermissionDeniedError | UnprocessableEntityError |
|---------------|----------------------------|---------------------|---------------------|---------------|-------------------------| |----------------------------|---------|----------------------------|------------------|---------------|-----------------------------|---------------------|----------|----------------|-------------------------|-----------------------|-------------------------|
| Anthropic | ✅ | ✅ | ✅ | ✅ | | | openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
| OpenAI | ✅ | ✅ |✅ |✅ |✅| | text-completion-openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
| Azure OpenAI | ✅ | ✅ |✅ |✅ |✅| | custom_openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
| Replicate | ✅ | ✅ | ✅ | ✅ | ✅ | | openai_compatible_providers| ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
| Cohere | ✅ | ✅ | ✅ | ✅ | ✅ | | anthropic | ✓ | ✓ | ✓ | ✓ | | ✓ | | | ✓ | ✓ | |
| Huggingface | ✅ | ✅ | ✅ | ✅ | | | replicate | ✓ | ✓ | ✓ | ✓ | | ✓ | | ✓ | ✓ | | |
| Openrouter | ✅ | ✅ | ✅ | ✅ | | | bedrock | ✓ | ✓ | ✓ | ✓ | | ✓ | | ✓ | ✓ | ✓ | |
| AI21 | ✅ | ✅ | ✅ | ✅ | | | sagemaker | | ✓ | ✓ | | | | | | | | |
| VertexAI | | |✅ | | | | vertex_ai | ✓ | | ✓ | | | | ✓ | | | | ✓ |
| Bedrock | | |✅ | | | | palm | ✓ | ✓ | | | | | ✓ | | | | |
| Sagemaker | | |✅ | | | | gemini | ✓ | ✓ | | | | | ✓ | | | | |
| TogetherAI | ✅ | ✅ | ✅ | ✅ | | | cloudflare | | | ✓ | | | ✓ | | | | | |
| AlephAlpha | ✅ | ✅ | ✅ | ✅ | ✅ | | cohere | | ✓ | ✓ | | | ✓ | | | ✓ | | |
| cohere_chat | | ✓ | ✓ | | | ✓ | | | ✓ | | |
| huggingface | ✓ | ✓ | ✓ | | | ✓ | | ✓ | ✓ | | |
| ai21 | ✓ | ✓ | ✓ | ✓ | | ✓ | | ✓ | | | |
| nlp_cloud | ✓ | ✓ | ✓ | | | ✓ | ✓ | ✓ | ✓ | | |
| together_ai | ✓ | ✓ | ✓ | | | ✓ | | | | | |
| aleph_alpha | | | ✓ | | | ✓ | | | | | |
| ollama | ✓ | | ✓ | | | | | | ✓ | | |
| ollama_chat | ✓ | | ✓ | | | | | | ✓ | | |
| vllm | | | | | | ✓ | ✓ | | | | |
| azure | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | | ✓ | | |
- "✓" indicates that the specified `custom_llm_provider` can raise the corresponding exception.
- Empty cells indicate the lack of association or that the provider does not raise that particular exception type as indicated by the function.
> For a deeper understanding of these exceptions, you can check out [this](https://github.com/BerriAI/litellm/blob/d7e58d13bf9ba9edbab2ab2f096f3de7547f35fa/litellm/utils.py#L1544) implementation for additional insights. > For a deeper understanding of these exceptions, you can check out [this](https://github.com/BerriAI/litellm/blob/d7e58d13bf9ba9edbab2ab2f096f3de7547f35fa/litellm/utils.py#L1544) implementation for additional insights.


@ -47,3 +47,12 @@ Pricing is based on usage. We can figure out a price that works for your team, o
<Image img={require('../img/litellm_hosted_ui_router.png')} /> <Image img={require('../img/litellm_hosted_ui_router.png')} />
#### [**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat) #### [**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
## Feature List
- Easy way to add/remove models
- 100% uptime even when models are added/removed
- Custom callback webhooks
- Your domain name with HTTPS
- Ability to create/delete User API keys
- Reasonable, fixed monthly cost


@ -14,14 +14,14 @@ import TabItem from '@theme/TabItem';
```python ```python
import os import os
from langchain.chat_models import ChatLiteLLM from langchain_community.chat_models import ChatLiteLLM
from langchain.prompts.chat import ( from langchain_core.prompts import (
ChatPromptTemplate, ChatPromptTemplate,
SystemMessagePromptTemplate, SystemMessagePromptTemplate,
AIMessagePromptTemplate, AIMessagePromptTemplate,
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
) )
from langchain.schema import AIMessage, HumanMessage, SystemMessage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
os.environ['OPENAI_API_KEY'] = "" os.environ['OPENAI_API_KEY'] = ""
chat = ChatLiteLLM(model="gpt-3.5-turbo") chat = ChatLiteLLM(model="gpt-3.5-turbo")
@ -30,7 +30,7 @@ messages = [
content="what model are you" content="what model are you"
) )
] ]
chat(messages) chat.invoke(messages)
``` ```
</TabItem> </TabItem>
@ -39,14 +39,14 @@ chat(messages)
```python ```python
import os import os
from langchain.chat_models import ChatLiteLLM from langchain_community.chat_models import ChatLiteLLM
from langchain.prompts.chat import ( from langchain_core.prompts import (
ChatPromptTemplate, ChatPromptTemplate,
SystemMessagePromptTemplate, SystemMessagePromptTemplate,
AIMessagePromptTemplate, AIMessagePromptTemplate,
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
) )
from langchain.schema import AIMessage, HumanMessage, SystemMessage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
os.environ['ANTHROPIC_API_KEY'] = "" os.environ['ANTHROPIC_API_KEY'] = ""
chat = ChatLiteLLM(model="claude-2", temperature=0.3) chat = ChatLiteLLM(model="claude-2", temperature=0.3)
@ -55,7 +55,7 @@ messages = [
content="what model are you" content="what model are you"
) )
] ]
chat(messages) chat.invoke(messages)
``` ```
</TabItem> </TabItem>
@ -64,14 +64,14 @@ chat(messages)
```python ```python
import os import os
from langchain.chat_models import ChatLiteLLM from langchain_community.chat_models import ChatLiteLLM
from langchain.prompts.chat import ( from langchain_core.prompts.chat import (
ChatPromptTemplate, ChatPromptTemplate,
SystemMessagePromptTemplate, SystemMessagePromptTemplate,
AIMessagePromptTemplate, AIMessagePromptTemplate,
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
) )
from langchain.schema import AIMessage, HumanMessage, SystemMessage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
os.environ['REPLICATE_API_TOKEN'] = "" os.environ['REPLICATE_API_TOKEN'] = ""
chat = ChatLiteLLM(model="replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1") chat = ChatLiteLLM(model="replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1")
@ -80,7 +80,7 @@ messages = [
content="what model are you?" content="what model are you?"
) )
] ]
chat(messages) chat.invoke(messages)
``` ```
</TabItem> </TabItem>
@ -89,14 +89,14 @@ chat(messages)
```python ```python
import os import os
from langchain.chat_models import ChatLiteLLM from langchain_community.chat_models import ChatLiteLLM
from langchain.prompts.chat import ( from langchain_core.prompts import (
ChatPromptTemplate, ChatPromptTemplate,
SystemMessagePromptTemplate, SystemMessagePromptTemplate,
AIMessagePromptTemplate, AIMessagePromptTemplate,
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
) )
from langchain.schema import AIMessage, HumanMessage, SystemMessage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
os.environ['COHERE_API_KEY'] = "" os.environ['COHERE_API_KEY'] = ""
chat = ChatLiteLLM(model="command-nightly") chat = ChatLiteLLM(model="command-nightly")
@ -105,32 +105,9 @@ messages = [
content="what model are you?" content="what model are you?"
) )
] ]
chat(messages) chat.invoke(messages)
``` ```
</TabItem>
<TabItem value="palm" label="PaLM - Google">
```python
import os
from langchain.chat_models import ChatLiteLLM
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
AIMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.schema import AIMessage, HumanMessage, SystemMessage
os.environ['PALM_API_KEY'] = ""
chat = ChatLiteLLM(model="palm/chat-bison")
messages = [
HumanMessage(
content="what model are you?"
)
]
chat(messages)
```
</TabItem> </TabItem>
</Tabs> </Tabs>


@ -94,9 +94,10 @@ print(response)
``` ```
### Set Custom Trace ID, Trace User ID and Tags ### Set Custom Trace ID, Trace User ID, Trace Metadata, Trace Version, Trace Release and Tags
Pass `trace_id`, `trace_user_id`, `trace_metadata`, `trace_version`, `trace_release`, `tags` in `metadata`
Pass `trace_id`, `trace_user_id` in `metadata`
```python ```python
import litellm import litellm
@ -121,12 +122,20 @@ response = completion(
metadata={ metadata={
"generation_name": "ishaan-test-generation", # set langfuse Generation Name "generation_name": "ishaan-test-generation", # set langfuse Generation Name
"generation_id": "gen-id22", # set langfuse Generation ID "generation_id": "gen-id22", # set langfuse Generation ID
"version": "test-generation-version" # set langfuse Generation Version
"trace_user_id": "user-id2", # set langfuse Trace User ID "trace_user_id": "user-id2", # set langfuse Trace User ID
"session_id": "session-1", # set langfuse Session ID "session_id": "session-1", # set langfuse Session ID
"tags": ["tag1", "tag2"] # set langfuse Tags "tags": ["tag1", "tag2"], # set langfuse Tags
"trace_id": "trace-id22", # set langfuse Trace ID "trace_id": "trace-id22", # set langfuse Trace ID
"trace_metadata": {"key": "value"}, # set langfuse Trace Metadata
"trace_version": "test-trace-version", # set langfuse Trace Version (if not set, defaults to Generation Version)
"trace_release": "test-trace-release", # set langfuse Trace Release
### OR ### ### OR ###
"existing_trace_id": "trace-id22", # if generation is continuation of past trace. This prevents default behaviour of setting a trace name "existing_trace_id": "trace-id22", # if generation is continuation of past trace. This prevents default behaviour of setting a trace name
### OR enforce that certain fields are overwritten in the trace during the continuation ###
"existing_trace_id": "trace-id22",
"trace_metadata": {"key": "updated_trace_value"}, # The new value to use for the langfuse Trace Metadata
"update_trace_keys": ["input", "output", "trace_metadata"], # Updates the trace input & output to be this generations input & output also updates the Trace Metadata to match the passed in value
}, },
) )
@ -134,6 +143,38 @@ print(response)
``` ```
### Trace & Generation Parameters
#### Trace Specific Parameters
* `trace_id` - Identifier for the trace; auto-generated by default. For an existing trace, use `existing_trace_id` instead of, or in conjunction with, `trace_id`
* `trace_name` - Name of the trace, auto-generated by default
* `session_id` - Session identifier for the trace, defaults to `None`
* `trace_version` - Version for the trace, defaults to value for `version`
* `trace_release` - Release for the trace, defaults to `None`
* `trace_metadata` - Metadata for the trace, defaults to `None`
* `trace_user_id` - User identifier for the trace, defaults to completion argument `user`
* `tags` - Tags for the trace, defaults to `None`
##### Updatable Parameters on Continuation
The following parameters can be updated on a continuation of a trace by passing the corresponding keys in `update_trace_keys` in the metadata of the completion.
* `input` - Will set the trace's input to be the input of this latest generation
* `output` - Will set the trace's output to be the output of this generation
* `trace_version` - Will set the trace version to be the provided value (to use the latest generation's version instead, use `version`)
* `trace_release` - Will set the trace release to be the provided value
* `trace_metadata` - Will set the trace metadata to the provided value
* `trace_user_id` - Will set the trace user id to the provided value
#### Generation Specific Parameters
* `generation_id` - Identifier for the generation, auto-generated by default
* `generation_name` - Name of the generation, auto-generated by default
* `prompt` - Langfuse prompt object used for the generation, defaults to None
Any other key-value pairs passed into the metadata of a `litellm` completion that are not listed in the spec above will be added as metadata key-value pairs on the generation.
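To tie the parameters above together, here is a hedged sketch (not from the original docs) of continuing an existing trace while overwriting selected trace fields; the ids and values are illustrative and the Langfuse keys are assumed to be set in the environment.

```python
import litellm
from litellm import completion

litellm.success_callback = ["langfuse"]

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "second turn of the conversation"}],
    metadata={
        "existing_trace_id": "trace-id22",                 # continue a past trace
        "trace_metadata": {"key": "updated_trace_value"},  # new metadata for the trace
        "update_trace_keys": ["input", "output", "trace_metadata"],  # trace fields to overwrite
    },
)
```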
### Use LangChain ChatLiteLLM + Langfuse ### Use LangChain ChatLiteLLM + Langfuse
Pass `trace_user_id`, `session_id` in model_kwargs Pass `trace_user_id`, `session_id` in model_kwargs
```python ```python


@ -535,7 +535,8 @@ print(response)
| Model Name | Function Call | | Model Name | Function Call |
|----------------------|---------------------------------------------| |----------------------|---------------------------------------------|
| Titan Embeddings - G1 | `embedding(model="bedrock/amazon.titan-embed-text-v1", input=input)` | | Titan Embeddings V2 | `embedding(model="bedrock/amazon.titan-embed-text-v2:0", input=input)` |
| Titan Embeddings - V1 | `embedding(model="bedrock/amazon.titan-embed-text-v1", input=input)` |
| Cohere Embeddings - English | `embedding(model="bedrock/cohere.embed-english-v3", input=input)` | | Cohere Embeddings - English | `embedding(model="bedrock/cohere.embed-english-v3", input=input)` |
| Cohere Embeddings - Multilingual | `embedding(model="bedrock/cohere.embed-multilingual-v3", input=input)` | | Cohere Embeddings - Multilingual | `embedding(model="bedrock/cohere.embed-multilingual-v3", input=input)` |
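A minimal sketch (not part of the diff) of calling the newly listed Titan V2 embedding model; it assumes AWS credentials and region are already configured in the environment.

```python
from litellm import embedding

# assumes AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_REGION_NAME are set
response = embedding(
    model="bedrock/amazon.titan-embed-text-v2:0",
    input=["good morning from litellm"],
)
print(response)
```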


@ -0,0 +1,54 @@
# Deepseek
https://deepseek.com/
**We support ALL Deepseek models, just set `deepseek/` as a prefix when sending completion requests**
## API Key
```python
# env variable
os.environ['DEEPSEEK_API_KEY']
```
## Sample Usage
```python
from litellm import completion
import os
os.environ['DEEPSEEK_API_KEY'] = ""
response = completion(
model="deepseek/deepseek-chat",
messages=[
{"role": "user", "content": "hello from litellm"}
],
)
print(response)
```
## Sample Usage - Streaming
```python
from litellm import completion
import os
os.environ['DEEPSEEK_API_KEY'] = ""
response = completion(
model="deepseek/deepseek-chat",
messages=[
{"role": "user", "content": "hello from litellm"}
],
stream=True
)
for chunk in response:
    print(chunk)
```
## Supported Models - ALL Deepseek Models Supported!
We support ALL Deepseek models, just set `deepseek/` as a prefix when sending completion requests
| Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| deepseek-chat | `completion(model="deepseek/deepseek-chat", messages)` |
| deepseek-coder | `completion(model="deepseek/deepseek-coder", messages)` |


@ -21,6 +21,11 @@ This is done by adding the "huggingface/" prefix to `model`, example `completion
<Tabs> <Tabs>
<TabItem value="tgi" label="Text-generation-interface (TGI)"> <TabItem value="tgi" label="Text-generation-interface (TGI)">
By default, LiteLLM will assume a huggingface call follows the TGI format.
<Tabs>
<TabItem value="sdk" label="SDK">
```python ```python
import os import os
from litellm import completion from litellm import completion
@ -40,9 +45,58 @@ response = completion(
print(response) print(response)
``` ```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: wizard-coder
litellm_params:
model: huggingface/WizardLM/WizardCoder-Python-34B-V1.0
api_key: os.environ/HUGGINGFACE_API_KEY
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "wizard-coder",
"messages": [
{
"role": "user",
"content": "I like you!"
}
]
}'
```
</TabItem>
</Tabs>
</TabItem> </TabItem>
<TabItem value="conv" label="Conversational-task (BlenderBot, etc.)"> <TabItem value="conv" label="Conversational-task (BlenderBot, etc.)">
Append `conversational` to the model name
e.g. `huggingface/conversational/<model-name>`
<Tabs>
<TabItem value="sdk" label="SDK">
```python ```python
import os import os
from litellm import completion from litellm import completion
@ -54,7 +108,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
# e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints # e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
response = completion( response = completion(
model="huggingface/facebook/blenderbot-400M-distill", model="huggingface/conversational/facebook/blenderbot-400M-distill",
messages=messages, messages=messages,
api_base="https://my-endpoint.huggingface.cloud" api_base="https://my-endpoint.huggingface.cloud"
) )
@ -62,7 +116,123 @@ response = completion(
print(response) print(response)
``` ```
</TabItem> </TabItem>
<TabItem value="none" label="Non TGI/Conversational-task LLMs"> <TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: blenderbot
litellm_params:
model: huggingface/conversational/facebook/blenderbot-400M-distill
api_key: os.environ/HUGGINGFACE_API_KEY
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "blenderbot",
"messages": [
{
"role": "user",
"content": "I like you!"
}
]
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="classification" label="Text Classification">
Append `text-classification` to the model name
e.g. `huggingface/text-classification/<model-name>`
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
# [OPTIONAL] set env var
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
messages = [{ "content": "I like you, I love you!","role": "user"}]
# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
response = completion(
model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
messages=messages,
api_base="https://my-endpoint.endpoints.huggingface.cloud",
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: bert-classifier
litellm_params:
model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
api_key: os.environ/HUGGINGFACE_API_KEY
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "bert-classifier",
"messages": [
{
"role": "user",
"content": "I like you!"
}
]
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="none" label="Text Generation (NOT TGI)">
Append `text-generation` to the model name
e.g. `huggingface/text-generation/<model-name>`
```python ```python
import os import os
@ -75,7 +245,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
# e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints # e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints
response = completion( response = completion(
model="huggingface/roneneldan/TinyStories-3M", model="huggingface/text-generation/roneneldan/TinyStories-3M",
messages=messages, messages=messages,
api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud", api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
) )


@ -45,13 +45,13 @@ for chunk in response:
All models listed here https://docs.mistral.ai/platform/endpoints are supported. We actively maintain the list of models, pricing, token window, etc. [here](https://github.com/BerriAI/litellm/blob/c1b25538277206b9f00de5254d80d6a83bb19a29/model_prices_and_context_window.json). All models listed here https://docs.mistral.ai/platform/endpoints are supported. We actively maintain the list of models, pricing, token window, etc. [here](https://github.com/BerriAI/litellm/blob/c1b25538277206b9f00de5254d80d6a83bb19a29/model_prices_and_context_window.json).
| Model Name | Function Call | | Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------| |----------------|--------------------------------------------------------------|
| mistral-tiny | `completion(model="mistral/mistral-tiny", messages)` | | Mistral Small | `completion(model="mistral/mistral-small-latest", messages)` |
| mistral-small | `completion(model="mistral/mistral-small", messages)` | | Mistral Medium | `completion(model="mistral/mistral-medium-latest", messages)`|
| mistral-medium | `completion(model="mistral/mistral-medium", messages)` | | Mistral Large | `completion(model="mistral/mistral-large-latest", messages)` |
| mistral-large-latest | `completion(model="mistral/mistral-large-latest", messages)` | | Mistral 7B | `completion(model="mistral/open-mistral-7b", messages)` |
| open-mixtral-8x22b | `completion(model="mistral/open-mixtral-8x22b", messages)` | | Mixtral 8x7B | `completion(model="mistral/open-mixtral-8x7b", messages)` |
| Mixtral 8x22B | `completion(model="mistral/open-mixtral-8x22b", messages)` |
## Function Calling ## Function Calling
@ -116,6 +116,6 @@ All models listed here https://docs.mistral.ai/platform/endpoints are supported
| Model Name | Function Call | | Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------| |--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| mistral-embed | `embedding(model="mistral/mistral-embed", input)` | | Mistral Embeddings | `embedding(model="mistral/mistral-embed", input)` |


@ -0,0 +1,247 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# 🆕 Predibase
LiteLLM supports all models on Predibase
## Usage
<Tabs>
<TabItem value="sdk" label="SDK">
### API KEYS
```python
import os
os.environ["PREDIBASE_API_KEY"] = ""
```
### Example Call
```python
from litellm import completion
import os
## set ENV variables
os.environ["PREDIBASE_API_KEY"] = "predibase key"
os.environ["PREDIBASE_TENANT_ID"] = "predibase tenant id"
# predibase llama-3 call
response = completion(
model="predibase/llama-3-8b-instruct",
messages = [{ "content": "Hello, how are you?","role": "user"}]
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: llama-3
litellm_params:
model: predibase/llama-3-8b-instruct
api_key: os.environ/PREDIBASE_API_KEY
tenant_id: os.environ/PREDIBASE_TENANT_ID
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Send Request to LiteLLM Proxy Server
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
```python
import openai
client = openai.OpenAI(
api_key="sk-1234", # pass litellm proxy key, if you're using virtual keys
base_url="http://0.0.0.0:4000" # litellm-proxy-base url
)
response = client.chat.completions.create(
model="llama-3",
messages = [
{
"role": "system",
"content": "Be a good human!"
},
{
"role": "user",
"content": "What do you know about earth?"
}
]
)
print(response)
```
</TabItem>
<TabItem value="curl" label="curl">
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "llama-3",
"messages": [
{
"role": "system",
"content": "Be a good human!"
},
{
"role": "user",
"content": "What do you know about earth?"
}
]
}'
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
## Advanced Usage - Prompt Formatting
LiteLLM has prompt template mappings for all `meta-llama` llama3 instruct models. [**See Code**](https://github.com/BerriAI/litellm/blob/4f46b4c3975cd0f72b8c5acb2cb429d23580c18a/litellm/llms/prompt_templates/factory.py#L1360)
To apply a custom prompt template:
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import litellm
from litellm import completion
import os

os.environ["PREDIBASE_API_KEY"] = ""

messages = [{"role": "user", "content": "Hello, how are you?"}]

# Create your own custom prompt template
litellm.register_prompt_template(
    model="togethercomputer/LLaMA-2-7B-32K",
    initial_prompt_value="You are a good assistant", # [OPTIONAL]
    roles={
        "system": {
            "pre_message": "[INST] <<SYS>>\n", # [OPTIONAL]
            "post_message": "\n<</SYS>>\n [/INST]\n" # [OPTIONAL]
        },
        "user": {
            "pre_message": "[INST] ", # [OPTIONAL]
            "post_message": " [/INST]" # [OPTIONAL]
        },
        "assistant": {
            "pre_message": "\n", # [OPTIONAL]
            "post_message": "\n" # [OPTIONAL]
        }
    },
    final_prompt_value="Now answer as best you can:" # [OPTIONAL]
)

def predibase_custom_model():
    model = "predibase/togethercomputer/LLaMA-2-7B-32K"
    response = completion(model=model, messages=messages)
    print(response['choices'][0]['message']['content'])
    return response

predibase_custom_model()
```
</TabItem>
<TabItem value="proxy" label="PROXY">
```yaml
# Model-specific parameters
model_list:
- model_name: mistral-7b # model alias
litellm_params: # actual params for litellm.completion()
model: "predibase/mistralai/Mistral-7B-Instruct-v0.1"
api_key: os.environ/PREDIBASE_API_KEY
initial_prompt_value: "\n"
roles: {"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}}
final_prompt_value: "\n"
bos_token: "<s>"
eos_token: "</s>"
max_tokens: 4096
```
</TabItem>
</Tabs>
## Passing additional params - max_tokens, temperature
See all litellm.completion supported params [here](https://docs.litellm.ai/docs/completion/input)
```python
# !pip install litellm
from litellm import completion
import os
## set ENV variables
os.environ["PREDIBASE_API_KEY"] = "predibase key"
# predibase llama-3 call
response = completion(
model="predibase/llama3-8b-instruct",
messages = [{ "content": "Hello, how are you?","role": "user"}],
max_tokens=20,
temperature=0.5
)
```
**proxy**
```yaml
model_list:
- model_name: llama-3
litellm_params:
model: predibase/llama-3-8b-instruct
api_key: os.environ/PREDIBASE_API_KEY
max_tokens: 20
temperature: 0.5
```
## Passing Predibase-specific params - adapter_id, adapter_source
Send params [not supported by `litellm.completion()`](https://docs.litellm.ai/docs/completion/input) but supported by Predibase by passing them to `litellm.completion`.
For example, `adapter_id` and `adapter_source` are Predibase-specific params - [See List](https://github.com/BerriAI/litellm/blob/8a35354dd6dbf4c2fcefcd6e877b980fcbd68c58/litellm/llms/predibase.py#L54)
```python
# !pip install litellm
from litellm import completion
import os
## set ENV variables
os.environ["PREDIBASE_API_KEY"] = "predibase key"
# predibase llama3 call
response = completion(
model="predibase/llama-3-8b-instruct",
messages = [{ "content": "Hello, how are you?","role": "user"}],
adapter_id="my_repo/3",
adapter_soruce="pbase",
)
```
**proxy**
```yaml
model_list:
- model_name: llama-3
litellm_params:
model: predibase/llama-3-8b-instruct
api_key: os.environ/PREDIBASE_API_KEY
adapter_id: my_repo/3
adapter_source: pbase
```


@ -0,0 +1,95 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Triton Inference Server
LiteLLM supports Embedding Models on Triton Inference Servers
## Usage
<Tabs>
<TabItem value="sdk" label="SDK">
### Example Call
Use the `triton/` prefix to route to triton server
```python
import asyncio
import litellm

async def main():
    response = await litellm.aembedding(
        model="triton/<your-triton-model>",
        api_base="https://your-triton-api-base/triton/embeddings",  # /embeddings endpoint you want litellm to call on your server
        input=["good morning from litellm"],
    )
    print(response)

asyncio.run(main())
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
- model_name: my-triton-model
litellm_params:
model: triton/<your-triton-model>
api_base: https://your-triton-api-base/triton/embeddings
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --detailed_debug
```
3. Send Request to LiteLLM Proxy Server
<Tabs>
<TabItem value="openai" label="OpenAI Python v1.0.0+">
```python
import openai
from openai import OpenAI
# set base_url to your proxy server
# set api_key to send to proxy server
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")
response = client.embeddings.create(
input=["hello from litellm"],
model="my-triton-model"
)
print(response)
```
</TabItem>
<TabItem value="curl" label="curl">
The `Authorization` header is optional; it's only required if you're using the litellm proxy with Virtual Keys.
```shell
curl --location 'http://0.0.0.0:4000/embeddings' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data ' {
"model": "my-triton-model",
"input": ["write a litellm poem"]
}'
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>


@ -477,6 +477,36 @@ print(response)
| code-gecko@latest| `completion('code-gecko@latest', messages)` | | code-gecko@latest| `completion('code-gecko@latest', messages)` |
## Embedding Models
#### Usage - Embedding
```python
import litellm
from litellm import embedding
litellm.vertex_project = "hardy-device-38811" # Your Project ID
litellm.vertex_location = "us-central1" # proj location
response = embedding(
model="vertex_ai/textembedding-gecko",
input=["good morning from litellm"],
)
print(response)
```
#### Supported Embedding Models
All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a0249f630a6792d49dffc2c5d9b7/model_prices_and_context_window.json#L835) are supported
| Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| textembedding-gecko | `embedding(model="vertex_ai/textembedding-gecko", input)` |
| textembedding-gecko-multilingual | `embedding(model="vertex_ai/textembedding-gecko-multilingual", input)` |
| textembedding-gecko-multilingual@001 | `embedding(model="vertex_ai/textembedding-gecko-multilingual@001", input)` |
| textembedding-gecko@001 | `embedding(model="vertex_ai/textembedding-gecko@001", input)` |
| textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` |
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
## Extra ## Extra
### Using `GOOGLE_APPLICATION_CREDENTIALS` ### Using `GOOGLE_APPLICATION_CREDENTIALS`
@ -520,6 +550,12 @@ def load_vertex_ai_credentials():
### Using GCP Service Account ### Using GCP Service Account
:::info
Trying to deploy LiteLLM on Google Cloud Run? Tutorial [here](https://docs.litellm.ai/docs/proxy/deploy#deploy-on-google-cloud-run)
:::
1. Figure out the Service Account bound to the Google Cloud Run service 1. Figure out the Service Account bound to the Google Cloud Run service
<Image img={require('../../img/gcp_acc_1.png')} /> <Image img={require('../../img/gcp_acc_1.png')} />


@ -0,0 +1,83 @@
# Region-based Routing
Route specific customers to eu-only models.
By specifying 'allowed_model_region' for a customer, LiteLLM will filter out any models in a model group that are not in the allowed region (i.e. 'eu').
[**See Code**](https://github.com/BerriAI/litellm/blob/5eb12e30cc5faa73799ebc7e48fc86ebf449c879/litellm/router.py#L2938)
### 1. Create customer with region-specification
Use the litellm 'end-user' object for this.
End-users can be tracked / id'ed by passing the 'user' param to litellm in an openai chat completion/embedding call.
```bash
curl -X POST --location 'http://0.0.0.0:4000/end_user/new' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"user_id" : "ishaan-jaff-45",
"allowed_model_region": "eu", # 👈 SPECIFY ALLOWED REGION='eu'
}'
```
### 2. Add eu models to model-group
Add eu models to a model group. For azure models, litellm can automatically infer the region (no need to set it).
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-35-turbo-eu # 👈 EU azure model
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
api_key: os.environ/AZURE_EUROPE_API_KEY
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/chatgpt-v-2
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_version: "2023-05-15"
api_key: os.environ/AZURE_API_KEY
router_settings:
enable_pre_call_checks: true # 👈 IMPORTANT
```
Start the proxy
```yaml
litellm --config /path/to/config.yaml
```
### 3. Test it!
Make a simple chat completions call to the proxy, passing the customer id from step 1 in the `user` field. In the response headers, you should see the returned api base.
```bash
curl -X POST --location 'http://localhost:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data '{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what is the meaning of the universe? 1234"
}],
"user": "ishaan-jaff-45" # 👈 USER ID
}
'
```
Expected API Base in response headers
```
x-litellm-api-base: "https://my-endpoint-europe-berri-992.openai.azure.com/"
```
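For reference, a hedged sketch of the same test call through the OpenAI Python client (not in the original doc); the key and base URL are placeholders.

```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what is the meaning of the universe? 1234"}],
    user="ishaan-jaff-45",  # end-user id created in step 1; triggers the region filter
)
print(response)
```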
### FAQ
**What happens if there are no available models for that region?**
Since the router filters out models that are not in the specified region, it will return an error to the user if no models in that region are available.


@ -915,39 +915,72 @@ Test Request
litellm --test litellm --test
``` ```
## Logging Proxy Input/Output Traceloop (OpenTelemetry) ## Logging Proxy Input/Output in OpenTelemetry format using Traceloop's OpenLLMetry
Traceloop allows you to log LLM Input/Output in the OpenTelemetry format [OpenLLMetry](https://github.com/traceloop/openllmetry) _(built and maintained by Traceloop)_ is a set of extensions
built on top of [OpenTelemetry](https://opentelemetry.io/) that gives you complete observability over your LLM
application. Because it uses OpenTelemetry under the
hood, [it can be connected to various observability solutions](https://www.traceloop.com/docs/openllmetry/integrations/introduction)
like:
We will use the `--config` to set `litellm.success_callback = ["traceloop"]` this will log all successfull LLM calls to traceloop * [Traceloop](https://www.traceloop.com/docs/openllmetry/integrations/traceloop)
* [Axiom](https://www.traceloop.com/docs/openllmetry/integrations/axiom)
* [Azure Application Insights](https://www.traceloop.com/docs/openllmetry/integrations/azure)
* [Datadog](https://www.traceloop.com/docs/openllmetry/integrations/datadog)
* [Dynatrace](https://www.traceloop.com/docs/openllmetry/integrations/dynatrace)
* [Grafana Tempo](https://www.traceloop.com/docs/openllmetry/integrations/grafana)
* [Honeycomb](https://www.traceloop.com/docs/openllmetry/integrations/honeycomb)
* [HyperDX](https://www.traceloop.com/docs/openllmetry/integrations/hyperdx)
* [Instana](https://www.traceloop.com/docs/openllmetry/integrations/instana)
* [New Relic](https://www.traceloop.com/docs/openllmetry/integrations/newrelic)
* [OpenTelemetry Collector](https://www.traceloop.com/docs/openllmetry/integrations/otel-collector)
* [Service Now Cloud Observability](https://www.traceloop.com/docs/openllmetry/integrations/service-now)
* [Sentry](https://www.traceloop.com/docs/openllmetry/integrations/sentry)
* [SigNoz](https://www.traceloop.com/docs/openllmetry/integrations/signoz)
* [Splunk](https://www.traceloop.com/docs/openllmetry/integrations/splunk)
**Step 1** Install traceloop-sdk and set Traceloop API key We will use the `--config` to set `litellm.success_callback = ["traceloop"]` to achieve this; the steps are listed below.
**Step 1:** Install the SDK
```shell ```shell
pip install traceloop-sdk -U pip install traceloop-sdk
``` ```
Traceloop outputs standard OpenTelemetry data that can be connected to your observability stack. Send standard OpenTelemetry from LiteLLM Proxy to [Traceloop](https://www.traceloop.com/docs/openllmetry/integrations/traceloop), [Dynatrace](https://www.traceloop.com/docs/openllmetry/integrations/dynatrace), [Datadog](https://www.traceloop.com/docs/openllmetry/integrations/datadog) **Step 2:** Configure Environment Variable for trace exporting
, [New Relic](https://www.traceloop.com/docs/openllmetry/integrations/newrelic), [Honeycomb](https://www.traceloop.com/docs/openllmetry/integrations/honeycomb), [Grafana Tempo](https://www.traceloop.com/docs/openllmetry/integrations/grafana), [Splunk](https://www.traceloop.com/docs/openllmetry/integrations/splunk), [OpenTelemetry Collector](https://www.traceloop.com/docs/openllmetry/integrations/otel-collector)
You will need to configure where to export your traces. Environment variables control this; for example, for Traceloop
you should use `TRACELOOP_API_KEY`, whereas for Datadog you use `TRACELOOP_BASE_URL`. For more,
visit [the Integrations Catalog](https://www.traceloop.com/docs/openllmetry/integrations/introduction).
If you are using Datadog as the observability solution, you can set `TRACELOOP_BASE_URL` as:
```shell
TRACELOOP_BASE_URL=http://<datadog-agent-hostname>:4318
```
**Step 3**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
```yaml ```yaml
model_list: model_list:
- model_name: gpt-3.5-turbo - model_name: gpt-3.5-turbo
litellm_params: litellm_params:
model: gpt-3.5-turbo model: gpt-3.5-turbo
api_key: my-fake-key # replace api_key with actual key
litellm_settings: litellm_settings:
success_callback: [ "traceloop" ] success_callback: [ "traceloop" ]
``` ```
**Step 3**: Start the proxy, make a test request **Step 4**: Start the proxy, make a test request
Start proxy Start proxy
```shell ```shell
litellm --config config.yaml --debug litellm --config config.yaml --debug
``` ```
Test Request Test Request
``` ```
curl --location 'http://0.0.0.0:4000/chat/completions' \ curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \ --header 'Content-Type: application/json' \

View file

@ -3,34 +3,38 @@ import TabItem from '@theme/TabItem';
# ⚡ Best Practices for Production # ⚡ Best Practices for Production
Expected Performance in Production ## 1. Use this config.yaml
Use this config.yaml in production (with your own LLMs)
1 LiteLLM Uvicorn Worker on Kubernetes
| Description | Value |
|--------------|-------|
| Avg latency | `50ms` |
| Median latency | `51ms` |
| `/chat/completions` Requests/second | `35` |
| `/chat/completions` Requests/minute | `2100` |
| `/chat/completions` Requests/hour | `126K` |
## 1. Switch off Debug Logging
Remove `set_verbose: True` from your config.yaml
```yaml ```yaml
model_list:
- model_name: fake-openai-endpoint
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
general_settings:
master_key: sk-1234 # enter your own master key, ensure it starts with 'sk-'
alerting: ["slack"] # Setup slack alerting - get alerts on LLM exceptions, Budget Alerts, Slow LLM Responses
proxy_batch_write_at: 60 # Batch write spend updates every 60s
litellm_settings: litellm_settings:
set_verbose: True set_verbose: False # Switch off Debug Logging, ensure your logs do not have any debugging on
``` ```
You should only see the following level of details in logs on the proxy server Set slack webhook url in your env
```shell ```shell
# INFO: 192.168.2.205:11774 - "POST /chat/completions HTTP/1.1" 200 OK export SLACK_WEBHOOK_URL="https://hooks.slack.com/services/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"
# INFO: 192.168.2.205:34717 - "POST /chat/completions HTTP/1.1" 200 OK
# INFO: 192.168.2.205:29734 - "POST /chat/completions HTTP/1.1" 200 OK
``` ```
:::info
Need help or want dedicated support? Talk to a founder [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat).
:::
## 2. On Kubernetes - Use 1 Uvicorn worker [Suggested CMD] ## 2. On Kubernetes - Use 1 Uvicorn worker [Suggested CMD]
Use this Docker `CMD`. This will start the proxy with 1 Uvicorn Async Worker Use this Docker `CMD`. This will start the proxy with 1 Uvicorn Async Worker
@ -40,21 +44,12 @@ Use this Docker `CMD`. This will start the proxy with 1 Uvicorn Async Worker
CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"] CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
``` ```
## 3. Batch write spend updates every 60s
The default proxy batch write is 10s. This is to make it easy to see spend when debugging locally. ## 3. Use Redis 'port','host', 'password'. NOT 'redis_url'
In production, we recommend using a longer interval period of 60s. This reduces the number of connections used to make DB writes. If you decide to use Redis, DO NOT use 'redis_url'. We recommend using redis port, host, and password params.
```yaml `redis_url` is 80 RPS slower
general_settings:
master_key: sk-1234
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
```
## 4. use Redis 'port','host', 'password'. NOT 'redis_url'
When connecting to Redis use redis port, host, and password params. Not 'redis_url'. We've seen a 80 RPS difference between these 2 approaches when using the async redis client.
This is still something we're investigating. Keep track of it [here](https://github.com/BerriAI/litellm/issues/3188) This is still something we're investigating. Keep track of it [here](https://github.com/BerriAI/litellm/issues/3188)
@ -69,103 +64,31 @@ router_settings:
redis_password: os.environ/REDIS_PASSWORD redis_password: os.environ/REDIS_PASSWORD
``` ```
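
For reference, a `router_settings` block using the individual Redis params might look like the sketch below; the env-var names are placeholders, adapt them to your deployment.

```yaml
router_settings:
  redis_host: os.environ/REDIS_HOST       # 👈 use host/port/password ...
  redis_port: os.environ/REDIS_PORT
  redis_password: os.environ/REDIS_PASSWORD
  # ... and do NOT set redis_url
```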
## 5. Switch off resetting budgets ## Extras
### Expected Performance in Production
Add this to your config.yaml. (Only spend per Key, User and Team will be tracked - spend per API Call will not be written to the LiteLLM Database) 1 LiteLLM Uvicorn Worker on Kubernetes
```yaml
general_settings:
disable_reset_budget: true
```
## 6. Move spend logs to separate server (BETA) | Description | Value |
|--------------|-------|
Writing each spend log to the db can slow down your proxy. In testing we saw a 70% improvement in median response time, by moving writing spend logs to a separate server. | Avg latency | `50ms` |
| Median latency | `51ms` |
👉 [LiteLLM Spend Logs Server](https://github.com/BerriAI/litellm/tree/main/litellm-js/spend-logs) | `/chat/completions` Requests/second | `35` |
| `/chat/completions` Requests/minute | `2100` |
| `/chat/completions` Requests/hour | `126K` |
**Spend Logs** ### Verifying Debugging logs are off
This is a log of the key, tokens, model, and latency for each call on the proxy.
[**Full Payload**](https://github.com/BerriAI/litellm/blob/8c9623a6bc4ad9da0a2dac64249a60ed8da719e8/litellm/proxy/utils.py#L1769) You should only see the following level of details in logs on the proxy server
```shell
# INFO: 192.168.2.205:11774 - "POST /chat/completions HTTP/1.1" 200 OK
**1. Start the spend logs server** # INFO: 192.168.2.205:34717 - "POST /chat/completions HTTP/1.1" 200 OK
# INFO: 192.168.2.205:29734 - "POST /chat/completions HTTP/1.1" 200 OK
```bash
docker run -p 3000:3000 \
-e DATABASE_URL="postgres://.." \
ghcr.io/berriai/litellm-spend_logs:main-latest
# RUNNING on http://0.0.0.0:3000
```
**2. Connect to proxy**
Example litellm_config.yaml
```yaml
model_list:
- model_name: fake-openai-endpoint
litellm_params:
model: openai/my-fake-model
api_key: my-fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
general_settings:
master_key: sk-1234
proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds)
```
Add `SPEND_LOGS_URL` as an environment variable when starting the proxy
```bash
docker run \
-v $(pwd)/litellm_config.yaml:/app/config.yaml \
-e DATABASE_URL="postgresql://.." \
-e SPEND_LOGS_URL="http://host.docker.internal:3000" \ # 👈 KEY CHANGE
-p 4000:4000 \
ghcr.io/berriai/litellm:main-latest \
--config /app/config.yaml --detailed_debug
# Running on http://0.0.0.0:4000
```
**3. Test Proxy!**
```bash
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data '{
"model": "fake-openai-endpoint",
"messages": [
{"role": "system", "content": "Be helpful"},
{"role": "user", "content": "What do you know?"}
]
}'
```
In your LiteLLM Spend Logs Server, you should see
**Expected Response**
```
Received and stored 1 logs. Total logs in memory: 1
...
Flushed 1 log to the DB.
``` ```
### Machine Specification ### Machine Specifications to Deploy LiteLLM
A t2.micro should be sufficient to handle 1k logs / minute on this server.
This consumes at max 120MB, and <0.1 vCPU.
## Machine Specifications to Deploy LiteLLM
| Service | Spec | CPUs | Memory | Architecture | Version| | Service | Spec | CPUs | Memory | Architecture | Version|
| --- | --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- | --- |
@ -173,7 +96,7 @@ This consumes at max 120MB, and <0.1 vCPU.
| Redis Cache | - | - | - | - | 7.0+ Redis Engine| | Redis Cache | - | - | - | - | 7.0+ Redis Engine|
## Reference Kubernetes Deployment YAML ### Reference Kubernetes Deployment YAML
Reference Kubernetes `deployment.yaml` that was load tested by us Reference Kubernetes `deployment.yaml` that was load tested by us

View file

@ -17,6 +17,7 @@ This is a new feature, and subject to changes based on feedback.
### Step 1. Setup Proxy ### Step 1. Setup Proxy
- `JWT_PUBLIC_KEY_URL`: This is the public keys endpoint of your OpenID provider. Typically it's `{openid-provider-base-url}/.well-known/openid-configuration/jwks`. For Keycloak it's `{keycloak_base_url}/realms/{your-realm}/protocol/openid-connect/certs`. - `JWT_PUBLIC_KEY_URL`: This is the public keys endpoint of your OpenID provider. Typically it's `{openid-provider-base-url}/.well-known/openid-configuration/jwks`. For Keycloak it's `{keycloak_base_url}/realms/{your-realm}/protocol/openid-connect/certs`.
- `JWT_AUDIENCE`: This is the audience used for decoding the JWT. If not set, the decode step will not verify the audience.
```bash ```bash
export JWT_PUBLIC_KEY_URL="" # "https://demo.duendesoftware.com/.well-known/openid-configuration/jwks" export JWT_PUBLIC_KEY_URL="" # "https://demo.duendesoftware.com/.well-known/openid-configuration/jwks"

View file

@ -12,8 +12,8 @@ Requirements:
You can set budgets at 3 levels: You can set budgets at 3 levels:
- For the proxy - For the proxy
- For a user - For an internal user
- For a 'user' passed to `/chat/completions`, `/embeddings` etc - For an end-user
- For a key - For a key
- For a key (model specific budgets) - For a key (model specific budgets)
@ -58,7 +58,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
}' }'
``` ```
</TabItem> </TabItem>
<TabItem value="per-user" label="For User"> <TabItem value="per-user" label="For Internal User">
Apply a budget across multiple keys. Apply a budget across multiple keys.
@ -165,12 +165,12 @@ curl --location 'http://localhost:4000/team/new' \
} }
``` ```
</TabItem> </TabItem>
<TabItem value="per-user-chat" label="For 'user' passed to /chat/completions"> <TabItem value="per-user-chat" label="For End User">
Use this to budget `user` passed to `/chat/completions`, **without needing to create a key for every user** Use this to budget `user` passed to `/chat/completions`, **without needing to create a key for every user**
**Step 1. Modify config.yaml** **Step 1. Modify config.yaml**
Define `litellm.max_user_budget` Define `litellm.max_end_user_budget`
```yaml ```yaml
general_settings: general_settings:
master_key: sk-1234 master_key: sk-1234
@ -328,7 +328,7 @@ You can set:
- max parallel requests - max parallel requests
<Tabs> <Tabs>
<TabItem value="per-user" label="Per User"> <TabItem value="per-user" label="Per Internal User">
Use `/user/new`, to persist rate limits across multiple keys. Use `/user/new`, to persist rate limits across multiple keys.
@ -408,7 +408,7 @@ curl --location 'http://localhost:4000/user/new' \
``` ```
## Create new keys for existing user ## Create new keys for existing internal user
Just include user_id in the `/key/generate` request. Just include user_id in the `/key/generate` request.

View file

@ -96,7 +96,7 @@ print(response)
- `router.aimage_generation()` - async image generation calls - `router.aimage_generation()` - async image generation calls
## Advanced - Routing Strategies ## Advanced - Routing Strategies
#### Routing Strategies - Weighted Pick, Rate Limit Aware, Least Busy, Latency Based #### Routing Strategies - Weighted Pick, Rate Limit Aware, Least Busy, Latency Based, Cost Based
Router provides 4 strategies for routing your calls across multiple deployments: Router provides 4 strategies for routing your calls across multiple deployments:
@ -467,6 +467,101 @@ async def router_acompletion():
asyncio.run(router_acompletion()) asyncio.run(router_acompletion())
``` ```
</TabItem>
<TabItem value="lowest-cost" label="Lowest Cost Routing (Async)">
Picks a deployment based on the lowest cost
How this works:
- Get all healthy deployments
- Select all deployments that are under their provided `rpm/tpm` limits
- For each deployment, check if `litellm_params["model"]` exists in [`litellm_model_cost_map`](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
  - if the deployment does not exist in `litellm_model_cost_map` -> use `deployment_cost = $1`
- Select the deployment with the lowest cost (see the sketch after this list)
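
For illustration only, the selection rule above can be sketched in a few lines of Python. This is not LiteLLM's internal implementation; the cost-map structure and the rpm/tpm filtering are simplified assumptions.

```python
# Illustrative sketch of the lowest-cost selection rule (not LiteLLM's internal code)
def pick_lowest_cost(healthy_deployments: list, model_cost_map: dict) -> dict:
    """Pick the deployment whose model has the lowest per-token cost."""

    def deployment_cost(deployment: dict) -> float:
        model = deployment["litellm_params"]["model"]
        costs = model_cost_map.get(model)
        if costs is None:
            return 1.0  # model not in the cost map -> assume a $1 cost
        # assumption: rank by combined input + output cost per token
        return costs.get("input_cost_per_token", 0.0) + costs.get("output_cost_per_token", 0.0)

    # assumption: deployments over their rpm/tpm limits were already filtered out
    return min(healthy_deployments, key=deployment_cost)
```

The Router example below shows the same strategy used end to end.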
```python
from litellm import Router
import asyncio
model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-4"},
"model_info": {"id": "openai-gpt-4"},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "groq/llama3-8b-8192"},
"model_info": {"id": "groq-llama"},
},
]
# init router
router = Router(model_list=model_list, routing_strategy="cost-based-routing")
async def router_acompletion():
response = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}]
)
print(response)
print(response._hidden_params["model_id"]) # expect groq-llama, since groq/llama has lowest cost
return response
asyncio.run(router_acompletion())
```
#### Using Custom Input/Output pricing
Set `litellm_params["input_cost_per_token"]` and `litellm_params["output_cost_per_token"]` to use custom pricing when routing
```python
model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"input_cost_per_token": 0.00003,
"output_cost_per_token": 0.00003,
},
"model_info": {"id": "chatgpt-v-experimental"},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-1",
"input_cost_per_token": 0.000000001,
"output_cost_per_token": 0.00000001,
},
"model_info": {"id": "chatgpt-v-1"},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-5",
"input_cost_per_token": 10,
"output_cost_per_token": 12,
},
"model_info": {"id": "chatgpt-v-5"},
},
]
# init router
router = Router(model_list=model_list, routing_strategy="cost-based-routing")
async def router_acompletion():
response = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}]
)
print(response)
print(response._hidden_params["model_id"]) # expect chatgpt-v-1, since chatgpt-v-1 has lowest cost
return response
asyncio.run(router_acompletion())
```
</TabItem> </TabItem>
</Tabs> </Tabs>
@ -616,6 +711,57 @@ response = router.completion(model="gpt-3.5-turbo", messages=messages)
print(f"response: {response}") print(f"response: {response}")
``` ```
#### Retries based on Error Type
Use `RetryPolicy` if you want to set a `num_retries` based on the Exception received
Example:
- 3 retries for `ContentPolicyViolationError`
- 0 retries for `AuthenticationError`
Example Usage
```python
import os
import litellm
from litellm.router import RetryPolicy
retry_policy = RetryPolicy(
ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors
AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
BadRequestErrorRetries=1,
TimeoutErrorRetries=2,
RateLimitErrorRetries=3,
)
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
},
{
"model_name": "bad-model", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
},
],
retry_policy=retry_policy,
)
response = await router.acompletion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
```
### Fallbacks ### Fallbacks
If a call fails after num_retries, fall back to another model group. If a call fails after num_retries, fall back to another model group.
@ -940,6 +1086,46 @@ async def test_acompletion_caching_on_router_caching_groups():
asyncio.run(test_acompletion_caching_on_router_caching_groups()) asyncio.run(test_acompletion_caching_on_router_caching_groups())
``` ```
## Alerting 🚨
Send alerts to slack / your webhook url for the following events
- LLM API Exceptions
- Slow LLM Responses
Get a slack webhook url from https://api.slack.com/messaging/webhooks
#### Usage
Initialize an `AlertingConfig` and pass it to `litellm.Router`. The following code will trigger an alert because `api_key="bad_key"`, which is invalid
```python
from litellm.router import AlertingConfig
import litellm
import os
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": "bad_key",
},
}
],
alerting_config= AlertingConfig(
alerting_threshold=10, # threshold for slow / hanging llm responses (in seconds). Defaults to 300 seconds
webhook_url= os.getenv("SLACK_WEBHOOK_URL") # webhook you want to send alerts to
),
)
try:
await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
except Exception:
pass
```
## Track cost for Azure Deployments ## Track cost for Azure Deployments
**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking **Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking
@ -1108,6 +1294,7 @@ def __init__(
"least-busy", "least-busy",
"usage-based-routing", "usage-based-routing",
"latency-based-routing", "latency-based-routing",
"cost-based-routing",
] = "simple-shuffle", ] = "simple-shuffle",
## DEBUGGING ## ## DEBUGGING ##

View file

@ -50,6 +50,7 @@ const sidebars = {
items: ["proxy/logging", "proxy/streaming_logging"], items: ["proxy/logging", "proxy/streaming_logging"],
}, },
"proxy/team_based_routing", "proxy/team_based_routing",
"proxy/customer_routing",
"proxy/ui", "proxy/ui",
"proxy/cost_tracking", "proxy/cost_tracking",
"proxy/token_auth", "proxy/token_auth",
@ -131,9 +132,13 @@ const sidebars = {
"providers/cohere", "providers/cohere",
"providers/anyscale", "providers/anyscale",
"providers/huggingface", "providers/huggingface",
"providers/watsonx",
"providers/predibase",
"providers/triton-inference-server",
"providers/ollama", "providers/ollama",
"providers/perplexity", "providers/perplexity",
"providers/groq", "providers/groq",
"providers/deepseek",
"providers/fireworks_ai", "providers/fireworks_ai",
"providers/vllm", "providers/vllm",
"providers/xinference", "providers/xinference",
@ -149,7 +154,7 @@ const sidebars = {
"providers/openrouter", "providers/openrouter",
"providers/custom_openai_proxy", "providers/custom_openai_proxy",
"providers/petals", "providers/petals",
"providers/watsonx",
], ],
}, },
"proxy/custom_pricing", "proxy/custom_pricing",

View file

@ -291,7 +291,7 @@ def _create_clickhouse_aggregate_tables(client=None, table_names=[]):
def _forecast_daily_cost(data: list): def _forecast_daily_cost(data: list):
import requests import requests # type: ignore
from datetime import datetime, timedelta from datetime import datetime, timedelta
if len(data) == 0: if len(data) == 0:

108
index.yaml Normal file
View file

@ -0,0 +1,108 @@
apiVersion: v1
entries:
litellm-helm:
- apiVersion: v2
appVersion: v1.35.38
created: "2024-05-06T10:22:24.384392-07:00"
dependencies:
- condition: db.deployStandalone
name: postgresql
repository: oci://registry-1.docker.io/bitnamicharts
version: '>=13.3.0'
- condition: redis.enabled
name: redis
repository: oci://registry-1.docker.io/bitnamicharts
version: '>=18.0.0'
description: Call all LLM APIs using the OpenAI format
digest: 60f0cfe9e7c1087437cb35f6fb7c43c3ab2be557b6d3aec8295381eb0dfa760f
name: litellm-helm
type: application
urls:
- litellm-helm-0.2.0.tgz
version: 0.2.0
postgresql:
- annotations:
category: Database
images: |
- name: os-shell
image: docker.io/bitnami/os-shell:12-debian-12-r16
- name: postgres-exporter
image: docker.io/bitnami/postgres-exporter:0.15.0-debian-12-r14
- name: postgresql
image: docker.io/bitnami/postgresql:16.2.0-debian-12-r6
licenses: Apache-2.0
apiVersion: v2
appVersion: 16.2.0
created: "2024-05-06T10:22:24.387717-07:00"
dependencies:
- name: common
repository: oci://registry-1.docker.io/bitnamicharts
tags:
- bitnami-common
version: 2.x.x
description: PostgreSQL (Postgres) is an open source object-relational database
known for reliability and data integrity. ACID-compliant, it supports foreign
keys, joins, views, triggers and stored procedures.
digest: 3c8125526b06833df32e2f626db34aeaedb29d38f03d15349db6604027d4a167
home: https://bitnami.com
icon: https://bitnami.com/assets/stacks/postgresql/img/postgresql-stack-220x234.png
keywords:
- postgresql
- postgres
- database
- sql
- replication
- cluster
maintainers:
- name: VMware, Inc.
url: https://github.com/bitnami/charts
name: postgresql
sources:
- https://github.com/bitnami/charts/tree/main/bitnami/postgresql
urls:
- charts/postgresql-14.3.1.tgz
version: 14.3.1
redis:
- annotations:
category: Database
images: |
- name: kubectl
image: docker.io/bitnami/kubectl:1.29.2-debian-12-r3
- name: os-shell
image: docker.io/bitnami/os-shell:12-debian-12-r16
- name: redis
image: docker.io/bitnami/redis:7.2.4-debian-12-r9
- name: redis-exporter
image: docker.io/bitnami/redis-exporter:1.58.0-debian-12-r4
- name: redis-sentinel
image: docker.io/bitnami/redis-sentinel:7.2.4-debian-12-r7
licenses: Apache-2.0
apiVersion: v2
appVersion: 7.2.4
created: "2024-05-06T10:22:24.391903-07:00"
dependencies:
- name: common
repository: oci://registry-1.docker.io/bitnamicharts
tags:
- bitnami-common
version: 2.x.x
description: Redis(R) is an open source, advanced key-value store. It is often
referred to as a data structure server since keys can contain strings, hashes,
lists, sets and sorted sets.
digest: b2fa1835f673a18002ca864c54fadac3c33789b26f6c5e58e2851b0b14a8f984
home: https://bitnami.com
icon: https://bitnami.com/assets/stacks/redis/img/redis-stack-220x234.png
keywords:
- redis
- keyvalue
- database
maintainers:
- name: VMware, Inc.
url: https://github.com/bitnami/charts
name: redis
sources:
- https://github.com/bitnami/charts/tree/main/bitnami/redis
urls:
- charts/redis-18.19.1.tgz
version: 18.19.1
generated: "2024-05-06T10:22:24.375026-07:00"

BIN
litellm-helm-0.2.0.tgz Normal file

Binary file not shown.

View file

@ -1,3 +1,7 @@
### Hide pydantic namespace conflict warnings globally ###
import warnings
warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
### INIT VARIABLES ### ### INIT VARIABLES ###
import threading, requests, os import threading, requests, os
from typing import Callable, List, Optional, Dict, Union, Any, Literal from typing import Callable, List, Optional, Dict, Union, Any, Literal
@ -71,9 +75,11 @@ maritalk_key: Optional[str] = None
ai21_key: Optional[str] = None ai21_key: Optional[str] = None
ollama_key: Optional[str] = None ollama_key: Optional[str] = None
openrouter_key: Optional[str] = None openrouter_key: Optional[str] = None
predibase_key: Optional[str] = None
huggingface_key: Optional[str] = None huggingface_key: Optional[str] = None
vertex_project: Optional[str] = None vertex_project: Optional[str] = None
vertex_location: Optional[str] = None vertex_location: Optional[str] = None
predibase_tenant_id: Optional[str] = None
togetherai_api_key: Optional[str] = None togetherai_api_key: Optional[str] = None
cloudflare_api_key: Optional[str] = None cloudflare_api_key: Optional[str] = None
baseten_key: Optional[str] = None baseten_key: Optional[str] = None
@ -361,6 +367,7 @@ openai_compatible_endpoints: List = [
"api.deepinfra.com/v1/openai", "api.deepinfra.com/v1/openai",
"api.mistral.ai/v1", "api.mistral.ai/v1",
"api.groq.com/openai/v1", "api.groq.com/openai/v1",
"api.deepseek.com/v1",
"api.together.xyz/v1", "api.together.xyz/v1",
] ]
@ -369,6 +376,7 @@ openai_compatible_providers: List = [
"anyscale", "anyscale",
"mistral", "mistral",
"groq", "groq",
"deepseek",
"deepinfra", "deepinfra",
"perplexity", "perplexity",
"xinference", "xinference",
@ -523,12 +531,15 @@ provider_list: List = [
"anyscale", "anyscale",
"mistral", "mistral",
"groq", "groq",
"deepseek",
"maritalk", "maritalk",
"voyage", "voyage",
"cloudflare", "cloudflare",
"xinference", "xinference",
"fireworks_ai", "fireworks_ai",
"watsonx", "watsonx",
"triton",
"predibase",
"custom", # custom apis "custom", # custom apis
] ]
@ -605,7 +616,6 @@ all_embedding_models = (
####### IMAGE GENERATION MODELS ################### ####### IMAGE GENERATION MODELS ###################
openai_image_generation_models = ["dall-e-2", "dall-e-3"] openai_image_generation_models = ["dall-e-2", "dall-e-3"]
from .timeout import timeout from .timeout import timeout
from .utils import ( from .utils import (
client, client,
@ -613,6 +623,8 @@ from .utils import (
get_optional_params, get_optional_params,
modify_integration, modify_integration,
token_counter, token_counter,
create_pretrained_tokenizer,
create_tokenizer,
cost_per_token, cost_per_token,
completion_cost, completion_cost,
supports_function_calling, supports_function_calling,
@ -636,9 +648,11 @@ from .utils import (
get_secret, get_secret,
get_supported_openai_params, get_supported_openai_params,
get_api_base, get_api_base,
get_first_chars_messages,
) )
from .llms.huggingface_restapi import HuggingfaceConfig from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig from .llms.anthropic import AnthropicConfig
from .llms.predibase import PredibaseConfig
from .llms.anthropic_text import AnthropicTextConfig from .llms.anthropic_text import AnthropicTextConfig
from .llms.replicate import ReplicateConfig from .llms.replicate import ReplicateConfig
from .llms.cohere import CohereConfig from .llms.cohere import CohereConfig
@ -692,3 +706,4 @@ from .exceptions import (
from .budget_manager import BudgetManager from .budget_manager import BudgetManager
from .proxy.proxy_cli import run_server from .proxy.proxy_cli import run_server
from .router import Router from .router import Router
from .assistants.main import *

View file

@ -10,8 +10,8 @@
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation # s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
import os import os
import inspect import inspect
import redis, litellm import redis, litellm # type: ignore
import redis.asyncio as async_redis import redis.asyncio as async_redis # type: ignore
from typing import List, Optional from typing import List, Optional

495
litellm/assistants/main.py Normal file
View file

@ -0,0 +1,495 @@
# What is this?
## Main file for assistants API logic
from typing import Iterable
import os
import litellm
from openai import OpenAI
from litellm import client
from litellm.utils import supports_httpx_timeout
from ..llms.openai import OpenAIAssistantsAPI
from ..types.llms.openai import *
from ..types.router import *
####### ENVIRONMENT VARIABLES ###################
openai_assistants_api = OpenAIAssistantsAPI()
### ASSISTANTS ###
def get_assistants(
custom_llm_provider: Literal["openai"],
client: Optional[OpenAI] = None,
**kwargs,
) -> SyncCursorPage[Assistant]:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
response: Optional[SyncCursorPage[Assistant]] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_assistants_api.get_assistants(
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
client=client,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'get_assistants'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
### THREADS ###
def create_thread(
custom_llm_provider: Literal["openai"],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]] = None,
metadata: Optional[dict] = None,
tool_resources: Optional[OpenAICreateThreadParamsToolResources] = None,
client: Optional[OpenAI] = None,
**kwargs,
) -> Thread:
"""
- get the llm provider
- if openai - route it there
- pass through relevant params
```
from litellm import create_thread
create_thread(
custom_llm_provider="openai",
### OPTIONAL ###
messages = [{
"role": "user",
"content": "Hello, what is AI?"
},
{
"role": "user",
"content": "How does AI work? Explain it in simple terms."
}]
)
```
"""
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
response: Optional[Thread] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_assistants_api.create_thread(
messages=messages,
metadata=metadata,
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
client=client,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
def get_thread(
custom_llm_provider: Literal["openai"],
thread_id: str,
client: Optional[OpenAI] = None,
**kwargs,
) -> Thread:
"""Get the thread object, given a thread_id"""
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
response: Optional[Thread] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_assistants_api.get_thread(
thread_id=thread_id,
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
client=client,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'get_thread'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
### MESSAGES ###
def add_message(
custom_llm_provider: Literal["openai"],
thread_id: str,
role: Literal["user", "assistant"],
content: str,
attachments: Optional[List[Attachment]] = None,
metadata: Optional[dict] = None,
client: Optional[OpenAI] = None,
**kwargs,
) -> OpenAIMessage:
### COMMON OBJECTS ###
message_data = MessageData(
role=role, content=content, attachments=attachments, metadata=metadata
)
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
response: Optional[OpenAIMessage] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_assistants_api.add_message(
thread_id=thread_id,
message_data=message_data,
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
client=client,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
def get_messages(
custom_llm_provider: Literal["openai"],
thread_id: str,
client: Optional[OpenAI] = None,
**kwargs,
) -> SyncCursorPage[OpenAIMessage]:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
response: Optional[SyncCursorPage[OpenAIMessage]] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_assistants_api.get_messages(
thread_id=thread_id,
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
client=client,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'get_messages'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
### RUNS ###
def run_thread(
custom_llm_provider: Literal["openai"],
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str] = None,
instructions: Optional[str] = None,
metadata: Optional[dict] = None,
model: Optional[str] = None,
stream: Optional[bool] = None,
tools: Optional[Iterable[AssistantToolParam]] = None,
client: Optional[OpenAI] = None,
**kwargs,
) -> Run:
"""Run a given thread + assistant."""
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
response: Optional[Run] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_assistants_api.run_thread(
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
stream=stream,
tools=tools,
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
client=client,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'run_thread'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response

View file

@ -10,7 +10,7 @@
import os, json, time import os, json, time
import litellm import litellm
from litellm.utils import ModelResponse from litellm.utils import ModelResponse
import requests, threading import requests, threading # type: ignore
from typing import Optional, Union, Literal from typing import Optional, Union, Literal

View file

@ -106,7 +106,7 @@ class InMemoryCache(BaseCache):
return_val.append(val) return_val.append(val)
return return_val return return_val
async def async_increment(self, key, value: int, **kwargs) -> int: async def async_increment(self, key, value: float, **kwargs) -> float:
# get the value # get the value
init_value = await self.async_get_cache(key=key) or 0 init_value = await self.async_get_cache(key=key) or 0
value = init_value + value value = init_value + value
@ -177,11 +177,18 @@ class RedisCache(BaseCache):
try: try:
# asyncio.get_running_loop().create_task(self.ping()) # asyncio.get_running_loop().create_task(self.ping())
result = asyncio.get_running_loop().create_task(self.ping()) result = asyncio.get_running_loop().create_task(self.ping())
except Exception: except Exception as e:
pass verbose_logger.error(
"Error connecting to Async Redis client", extra={"error": str(e)}
)
### SYNC HEALTH PING ### ### SYNC HEALTH PING ###
try:
self.redis_client.ping() self.redis_client.ping()
except Exception as e:
verbose_logger.error(
"Error connecting to Sync Redis client", extra={"error": str(e)}
)
def init_async_client(self): def init_async_client(self):
from ._redis import get_redis_async_client from ._redis import get_redis_async_client
@ -416,12 +423,12 @@ class RedisCache(BaseCache):
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size: if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
await self.flush_cache_buffer() # logging done in here await self.flush_cache_buffer() # logging done in here
async def async_increment(self, key, value: int, **kwargs) -> int: async def async_increment(self, key, value: float, **kwargs) -> float:
_redis_client = self.init_async_client() _redis_client = self.init_async_client()
start_time = time.time() start_time = time.time()
try: try:
async with _redis_client as redis_client: async with _redis_client as redis_client:
result = await redis_client.incr(name=key, amount=value) result = await redis_client.incrbyfloat(name=key, amount=value)
## LOGGING ## ## LOGGING ##
end_time = time.time() end_time = time.time()
_duration = end_time - start_time _duration = end_time - start_time
@ -1375,18 +1382,41 @@ class DualCache(BaseCache):
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}") print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc() traceback.print_exc()
async def async_batch_set_cache(
self, cache_list: list, local_only: bool = False, **kwargs
):
"""
Batch write values to the cache
"""
print_verbose(
f"async batch set cache: cache keys: {cache_list}; local_only: {local_only}"
)
try:
if self.in_memory_cache is not None:
await self.in_memory_cache.async_set_cache_pipeline(
cache_list=cache_list, **kwargs
)
if self.redis_cache is not None and local_only == False:
await self.redis_cache.async_set_cache_pipeline(
cache_list=cache_list, ttl=kwargs.get("ttl", None)
)
except Exception as e:
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
traceback.print_exc()
async def async_increment_cache( async def async_increment_cache(
self, key, value: int, local_only: bool = False, **kwargs self, key, value: float, local_only: bool = False, **kwargs
) -> int: ) -> float:
""" """
Key - the key in cache Key - the key in cache
Value - int - the value you want to increment by Value - float - the value you want to increment by
Returns - int - the incremented value Returns - float - the incremented value
""" """
try: try:
result: int = value result: float = value
if self.in_memory_cache is not None: if self.in_memory_cache is not None:
result = await self.in_memory_cache.async_increment( result = await self.in_memory_cache.async_increment(
key, value, **kwargs key, value, **kwargs

View file

@ -1,7 +1,6 @@
#### What this does #### #### What this does ####
# On success + failure, log events to aispend.io # On success + failure, log events to aispend.io
import dotenv, os import dotenv, os
import requests
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback

View file

@ -4,18 +4,30 @@ import datetime
class AthinaLogger: class AthinaLogger:
def __init__(self): def __init__(self):
import os import os
self.athina_api_key = os.getenv("ATHINA_API_KEY") self.athina_api_key = os.getenv("ATHINA_API_KEY")
self.headers = { self.headers = {
"athina-api-key": self.athina_api_key, "athina-api-key": self.athina_api_key,
"Content-Type": "application/json" "Content-Type": "application/json",
} }
self.athina_logging_url = "https://log.athina.ai/api/v1/log/inference" self.athina_logging_url = "https://log.athina.ai/api/v1/log/inference"
self.additional_keys = ["environment", "prompt_slug", "customer_id", "customer_user_id", "session_id", "external_reference_id", "context", "expected_response", "user_query"] self.additional_keys = [
"environment",
"prompt_slug",
"customer_id",
"customer_user_id",
"session_id",
"external_reference_id",
"context",
"expected_response",
"user_query",
]
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
import requests import requests # type: ignore
import json import json
import traceback import traceback
try: try:
response_json = response_obj.model_dump() if response_obj else {} response_json = response_obj.model_dump() if response_obj else {}
data = { data = {
@ -23,19 +35,30 @@ class AthinaLogger:
"request": kwargs, "request": kwargs,
"response": response_json, "response": response_json,
"prompt_tokens": response_json.get("usage", {}).get("prompt_tokens"), "prompt_tokens": response_json.get("usage", {}).get("prompt_tokens"),
"completion_tokens": response_json.get("usage", {}).get("completion_tokens"), "completion_tokens": response_json.get("usage", {}).get(
"completion_tokens"
),
"total_tokens": response_json.get("usage", {}).get("total_tokens"), "total_tokens": response_json.get("usage", {}).get("total_tokens"),
} }
if type(end_time) == datetime.datetime and type(start_time) == datetime.datetime: if (
data["response_time"] = int((end_time - start_time).total_seconds() * 1000) type(end_time) == datetime.datetime
and type(start_time) == datetime.datetime
):
data["response_time"] = int(
(end_time - start_time).total_seconds() * 1000
)
if "messages" in kwargs: if "messages" in kwargs:
data["prompt"] = kwargs.get("messages", None) data["prompt"] = kwargs.get("messages", None)
# Directly add tools or functions if present # Directly add tools or functions if present
optional_params = kwargs.get("optional_params", {}) optional_params = kwargs.get("optional_params", {})
data.update((k, v) for k, v in optional_params.items() if k in ["tools", "functions"]) data.update(
(k, v)
for k, v in optional_params.items()
if k in ["tools", "functions"]
)
# Add additional metadata keys # Add additional metadata keys
metadata = kwargs.get("litellm_params", {}).get("metadata", {}) metadata = kwargs.get("litellm_params", {}).get("metadata", {})
@ -44,11 +67,19 @@ class AthinaLogger:
if key in metadata: if key in metadata:
data[key] = metadata[key] data[key] = metadata[key]
response = requests.post(self.athina_logging_url, headers=self.headers, data=json.dumps(data, default=str)) response = requests.post(
self.athina_logging_url,
headers=self.headers,
data=json.dumps(data, default=str),
)
if response.status_code != 200: if response.status_code != 200:
print_verbose(f"Athina Logger Error - {response.text}, {response.status_code}") print_verbose(
f"Athina Logger Error - {response.text}, {response.status_code}"
)
else: else:
print_verbose(f"Athina Logger Succeeded - {response.text}") print_verbose(f"Athina Logger Succeeded - {response.text}")
except Exception as e: except Exception as e:
print_verbose(f"Athina Logger Error - {e}, Stack trace: {traceback.format_exc()}") print_verbose(
f"Athina Logger Error - {e}, Stack trace: {traceback.format_exc()}"
)
pass pass

View file

@ -1,7 +1,7 @@
#### What this does #### #### What this does ####
# On success + failure, log events to aispend.io # On success + failure, log events to aispend.io
import dotenv, os import dotenv, os
import requests import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback

View file

@ -3,7 +3,6 @@
#### What this does #### #### What this does ####
# On success, logs events to Promptlayer # On success, logs events to Promptlayer
import dotenv, os import dotenv, os
import requests
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache from litellm.caching import DualCache

View file

@ -1,7 +1,6 @@
#### What this does #### #### What this does ####
# On success, logs events to Promptlayer # On success, logs events to Promptlayer
import dotenv, os import dotenv, os
import requests
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache from litellm.caching import DualCache

View file

@ -2,7 +2,7 @@
# On success + failure, log events to Supabase # On success + failure, log events to Supabase
import dotenv, os import dotenv, os
import requests import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback

View file

@ -2,7 +2,7 @@
# On success + failure, log events to Supabase # On success + failure, log events to Supabase
import dotenv, os import dotenv, os
import requests import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback

View file

@ -1,15 +1,17 @@
import requests import requests # type: ignore
import json import json
import traceback import traceback
from datetime import datetime, timezone from datetime import datetime, timezone
class GreenscaleLogger: class GreenscaleLogger:
def __init__(self): def __init__(self):
import os import os
self.greenscale_api_key = os.getenv("GREENSCALE_API_KEY") self.greenscale_api_key = os.getenv("GREENSCALE_API_KEY")
self.headers = { self.headers = {
"api-key": self.greenscale_api_key, "api-key": self.greenscale_api_key,
"Content-Type": "application/json" "Content-Type": "application/json",
} }
self.greenscale_logging_url = os.getenv("GREENSCALE_ENDPOINT") self.greenscale_logging_url = os.getenv("GREENSCALE_ENDPOINT")
@ -19,13 +21,18 @@ class GreenscaleLogger:
data = { data = {
"modelId": kwargs.get("model"), "modelId": kwargs.get("model"),
"inputTokenCount": response_json.get("usage", {}).get("prompt_tokens"), "inputTokenCount": response_json.get("usage", {}).get("prompt_tokens"),
"outputTokenCount": response_json.get("usage", {}).get("completion_tokens"), "outputTokenCount": response_json.get("usage", {}).get(
"completion_tokens"
),
} }
data["timestamp"] = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ') data["timestamp"] = datetime.now(timezone.utc).strftime(
"%Y-%m-%dT%H:%M:%SZ"
)
if type(end_time) == datetime and type(start_time) == datetime: if type(end_time) == datetime and type(start_time) == datetime:
data["invocationLatency"] = int((end_time - start_time).total_seconds() * 1000) data["invocationLatency"] = int(
(end_time - start_time).total_seconds() * 1000
)
# Add additional metadata keys to tags # Add additional metadata keys to tags
tags = [] tags = []
@ -37,15 +44,25 @@ class GreenscaleLogger:
elif key == "greenscale_application": elif key == "greenscale_application":
data["application"] = value data["application"] = value
else: else:
tags.append({"key": key.replace("greenscale_", ""), "value": str(value)}) tags.append(
{"key": key.replace("greenscale_", ""), "value": str(value)}
)
data["tags"] = tags data["tags"] = tags
response = requests.post(self.greenscale_logging_url, headers=self.headers, data=json.dumps(data, default=str)) response = requests.post(
self.greenscale_logging_url,
headers=self.headers,
data=json.dumps(data, default=str),
)
if response.status_code != 200: if response.status_code != 200:
print_verbose(f"Greenscale Logger Error - {response.text}, {response.status_code}") print_verbose(
f"Greenscale Logger Error - {response.text}, {response.status_code}"
)
else: else:
print_verbose(f"Greenscale Logger Succeeded - {response.text}") print_verbose(f"Greenscale Logger Succeeded - {response.text}")
except Exception as e: except Exception as e:
print_verbose(f"Greenscale Logger Error - {e}, Stack trace: {traceback.format_exc()}") print_verbose(
f"Greenscale Logger Error - {e}, Stack trace: {traceback.format_exc()}"
)
pass pass

View file

@ -1,7 +1,7 @@
#### What this does #### #### What this does ####
# On success, logs events to Helicone # On success, logs events to Helicone
import dotenv, os import dotenv, os
import requests import requests # type: ignore
import litellm import litellm
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv

View file

@ -262,6 +262,23 @@ class LangFuseLogger:
try: try:
tags = [] tags = []
try:
metadata = copy.deepcopy(
metadata
) # Avoid modifying the original metadata
except:
new_metadata = {}
for key, value in metadata.items():
if (
isinstance(value, list)
or isinstance(value, dict)
or isinstance(value, str)
or isinstance(value, int)
or isinstance(value, float)
):
new_metadata[key] = copy.deepcopy(value)
metadata = new_metadata
supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3") supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3")
supports_prompt = Version(langfuse.version.__version__) >= Version("2.7.3") supports_prompt = Version(langfuse.version.__version__) >= Version("2.7.3")
supports_costs = Version(langfuse.version.__version__) >= Version("2.7.3") supports_costs = Version(langfuse.version.__version__) >= Version("2.7.3")
@ -272,36 +289,9 @@ class LangFuseLogger:
print_verbose(f"Langfuse Layer Logging - logging to langfuse v2 ") print_verbose(f"Langfuse Layer Logging - logging to langfuse v2 ")
if supports_tags: if supports_tags:
metadata_tags = metadata.get("tags", []) metadata_tags = metadata.pop("tags", [])
tags = metadata_tags tags = metadata_tags
trace_name = metadata.get("trace_name", None)
trace_id = metadata.get("trace_id", None)
existing_trace_id = metadata.get("existing_trace_id", None)
if trace_name is None and existing_trace_id is None:
# just log `litellm-{call_type}` as the trace name
## DO NOT SET TRACE_NAME if trace-id set. this can lead to overwriting of past traces.
trace_name = f"litellm-{kwargs.get('call_type', 'completion')}"
if existing_trace_id is not None:
trace_params = {"id": existing_trace_id}
else: # don't overwrite an existing trace
trace_params = {
"name": trace_name,
"input": input,
"user_id": metadata.get("trace_user_id", user_id),
"id": trace_id,
"session_id": metadata.get("session_id", None),
}
if level == "ERROR":
trace_params["status_message"] = output
else:
trace_params["output"] = output
cost = kwargs.get("response_cost", None)
print_verbose(f"trace: {cost}")
# Clean Metadata before logging - never log raw metadata # Clean Metadata before logging - never log raw metadata
# the raw metadata can contain circular references which leads to infinite recursion # the raw metadata can contain circular references which leads to infinite recursion
# we clean out all extra litellm metadata params before logging # we clean out all extra litellm metadata params before logging
@ -328,6 +318,67 @@ class LangFuseLogger:
else: else:
clean_metadata[key] = value clean_metadata[key] = value
session_id = clean_metadata.pop("session_id", None)
trace_name = clean_metadata.pop("trace_name", None)
trace_id = clean_metadata.pop("trace_id", None)
existing_trace_id = clean_metadata.pop("existing_trace_id", None)
update_trace_keys = clean_metadata.pop("update_trace_keys", [])
if trace_name is None and existing_trace_id is None:
# just log `litellm-{call_type}` as the trace name
## DO NOT SET TRACE_NAME if trace-id set. this can lead to overwriting of past traces.
trace_name = f"litellm-{kwargs.get('call_type', 'completion')}"
if existing_trace_id is not None:
trace_params = {"id": existing_trace_id}
# Update the following keys for this trace
for metadata_param_key in update_trace_keys:
trace_param_key = metadata_param_key.replace("trace_", "")
if trace_param_key not in trace_params:
updated_trace_value = clean_metadata.pop(
metadata_param_key, None
)
if updated_trace_value is not None:
trace_params[trace_param_key] = updated_trace_value
# Pop the trace specific keys that would have been popped if there were a new trace
for key in list(
filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
):
clean_metadata.pop(key, None)
# Special keys that are found in the function arguments and not the metadata
if "input" in update_trace_keys:
trace_params["input"] = input
if "output" in update_trace_keys:
trace_params["output"] = output
else: # don't overwrite an existing trace
trace_params = {
"id": trace_id,
"name": trace_name,
"session_id": session_id,
"input": input,
"version": clean_metadata.pop(
"trace_version", clean_metadata.get("version", None)
), # If only a version is provided, it is applied to the trace as well; a trace_version takes precedence
"user_id": user_id,
}
for key in list(
filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
):
trace_params[key.replace("trace_", "")] = clean_metadata.pop(
key, None
)
if level == "ERROR":
trace_params["status_message"] = output
else:
trace_params["output"] = output
cost = kwargs.get("response_cost", None)
print_verbose(f"trace: {cost}")
if ( if (
litellm._langfuse_default_tags is not None litellm._langfuse_default_tags is not None
and isinstance(litellm._langfuse_default_tags, list) and isinstance(litellm._langfuse_default_tags, list)
@ -387,7 +438,7 @@ class LangFuseLogger:
"completion_tokens": response_obj["usage"]["completion_tokens"], "completion_tokens": response_obj["usage"]["completion_tokens"],
"total_cost": cost if supports_costs else None, "total_cost": cost if supports_costs else None,
} }
generation_name = metadata.get("generation_name", None) generation_name = clean_metadata.pop("generation_name", None)
if generation_name is None: if generation_name is None:
# just log `litellm-{call_type}` as the generation name # just log `litellm-{call_type}` as the generation name
generation_name = f"litellm-{kwargs.get('call_type', 'completion')}" generation_name = f"litellm-{kwargs.get('call_type', 'completion')}"
@ -402,7 +453,7 @@ class LangFuseLogger:
generation_params = { generation_params = {
"name": generation_name, "name": generation_name,
"id": metadata.get("generation_id", generation_id), "id": clean_metadata.pop("generation_id", generation_id),
"start_time": start_time, "start_time": start_time,
"end_time": end_time, "end_time": end_time,
"model": kwargs["model"], "model": kwargs["model"],
@ -412,10 +463,11 @@ class LangFuseLogger:
"usage": usage, "usage": usage,
"metadata": clean_metadata, "metadata": clean_metadata,
"level": level, "level": level,
"version": clean_metadata.pop("version", None),
} }
if supports_prompt: if supports_prompt:
generation_params["prompt"] = metadata.get("prompt", None) generation_params["prompt"] = clean_metadata.pop("prompt", None)
if output is not None and isinstance(output, str) and level == "ERROR": if output is not None and isinstance(output, str) and level == "ERROR":
generation_params["status_message"] = output generation_params["status_message"] = output

View file

@ -1,15 +1,14 @@
#### What this does #### #### What this does ####
# On success, logs events to Langsmith # On success, logs events to Langsmith
import dotenv, os import dotenv, os # type: ignore
import requests import requests # type: ignore
import requests
from datetime import datetime from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
import asyncio import asyncio
import types import types
from pydantic import BaseModel from pydantic import BaseModel # type: ignore
def is_serializable(value): def is_serializable(value):
@ -79,8 +78,6 @@ class LangsmithLogger:
except: except:
response_obj = response_obj.dict() # type: ignore response_obj = response_obj.dict() # type: ignore
print(f"response_obj: {response_obj}")
data = { data = {
"name": run_name, "name": run_name,
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain" "run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
@ -90,7 +87,6 @@ class LangsmithLogger:
"start_time": start_time, "start_time": start_time,
"end_time": end_time, "end_time": end_time,
} }
print(f"data: {data}")
response = requests.post( response = requests.post(
"https://api.smith.langchain.com/runs", "https://api.smith.langchain.com/runs",

View file

@ -4,7 +4,6 @@ from datetime import datetime, timezone
import traceback import traceback
import dotenv import dotenv
import importlib import importlib
import sys
import packaging import packaging
@ -18,13 +17,33 @@ def parse_usage(usage):
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0, "prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
} }
def parse_tool_calls(tool_calls):
if tool_calls is None:
return None
def clean_tool_call(tool_call):
serialized = {
"type": tool_call.type,
"id": tool_call.id,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
}
}
return serialized
return [clean_tool_call(tool_call) for tool_call in tool_calls]
def parse_messages(input): def parse_messages(input):
if input is None: if input is None:
return None return None
def clean_message(message): def clean_message(message):
# if is strin, return as is # if is string, return as is
if isinstance(message, str): if isinstance(message, str):
return message return message
@ -38,9 +57,7 @@ def parse_messages(input):
# Only add tool_calls and function_call to res if they are set # Only add tool_calls and function_call to res if they are set
if message.get("tool_calls"): if message.get("tool_calls"):
serialized["tool_calls"] = message.get("tool_calls") serialized["tool_calls"] = parse_tool_calls(message.get("tool_calls"))
if message.get("function_call"):
serialized["function_call"] = message.get("function_call")
return serialized return serialized
@ -93,8 +110,13 @@ class LunaryLogger:
print_verbose(f"Lunary Logging - Logging request for model {model}") print_verbose(f"Lunary Logging - Logging request for model {model}")
litellm_params = kwargs.get("litellm_params", {}) litellm_params = kwargs.get("litellm_params", {})
optional_params = kwargs.get("optional_params", {})
metadata = litellm_params.get("metadata", {}) or {} metadata = litellm_params.get("metadata", {}) or {}
if optional_params:
# merge into extra
extra = {**extra, **optional_params}
tags = litellm_params.pop("tags", None) or [] tags = litellm_params.pop("tags", None) or []
if extra: if extra:
@ -104,7 +126,7 @@ class LunaryLogger:
# keep only serializable types # keep only serializable types
for param, value in extra.items(): for param, value in extra.items():
if not isinstance(value, (str, int, bool, float)): if not isinstance(value, (str, int, bool, float)) and param != "tools":
try: try:
extra[param] = str(value) extra[param] = str(value)
except: except:
@ -140,7 +162,7 @@ class LunaryLogger:
metadata=metadata, metadata=metadata,
runtime="litellm", runtime="litellm",
tags=tags, tags=tags,
extra=extra, params=extra,
) )
self.lunary_client.track_event( self.lunary_client.track_event(

View file

@ -2,7 +2,6 @@
## On Success events log cost to OpenMeter - https://github.com/BerriAI/litellm/issues/1268 ## On Success events log cost to OpenMeter - https://github.com/BerriAI/litellm/issues/1268
import dotenv, os, json import dotenv, os, json
import requests
import litellm import litellm
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
@ -38,7 +37,7 @@ class OpenMeterLogger(CustomLogger):
in the environment in the environment
""" """
missing_keys = [] missing_keys = []
if litellm.get_secret("OPENMETER_API_KEY", None) is None: if os.getenv("OPENMETER_API_KEY", None) is None:
missing_keys.append("OPENMETER_API_KEY") missing_keys.append("OPENMETER_API_KEY")
if len(missing_keys) > 0: if len(missing_keys) > 0:
@ -60,47 +59,56 @@ class OpenMeterLogger(CustomLogger):
"total_tokens": response_obj["usage"].get("total_tokens"), "total_tokens": response_obj["usage"].get("total_tokens"),
} }
subject = kwargs.get("user", None) # end-user passed in via 'user' param
if not subject:
raise Exception("OpenMeter: user is required")
return { return {
"specversion": "1.0", "specversion": "1.0",
"type": os.getenv("OPENMETER_EVENT_TYPE", "litellm_tokens"), "type": os.getenv("OPENMETER_EVENT_TYPE", "litellm_tokens"),
"id": call_id, "id": call_id,
"time": dt, "time": dt,
"subject": kwargs.get("user", ""), # end-user passed in via 'user' param "subject": subject,
"source": "litellm-proxy", "source": "litellm-proxy",
"data": {"model": model, "cost": cost, **usage}, "data": {"model": model, "cost": cost, **usage},
} }
def log_success_event(self, kwargs, response_obj, start_time, end_time): def log_success_event(self, kwargs, response_obj, start_time, end_time):
_url = litellm.get_secret( _url = os.getenv("OPENMETER_API_ENDPOINT", "https://openmeter.cloud")
"OPENMETER_API_ENDPOINT", default_value="https://openmeter.cloud"
)
if _url.endswith("/"): if _url.endswith("/"):
_url += "api/v1/events" _url += "api/v1/events"
else: else:
_url += "/api/v1/events" _url += "/api/v1/events"
api_key = litellm.get_secret("OPENMETER_API_KEY") api_key = os.getenv("OPENMETER_API_KEY")
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj) _data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
self.sync_http_handler.post( _headers = {
url=_url,
data=_data,
headers={
"Content-Type": "application/cloudevents+json", "Content-Type": "application/cloudevents+json",
"Authorization": "Bearer {}".format(api_key), "Authorization": "Bearer {}".format(api_key),
}, }
try:
response = self.sync_http_handler.post(
url=_url,
data=json.dumps(_data),
headers=_headers,
) )
response.raise_for_status()
except Exception as e:
if hasattr(response, "text"):
litellm.print_verbose(f"\nError Message: {response.text}")
raise e
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
_url = litellm.get_secret( _url = os.getenv("OPENMETER_API_ENDPOINT", "https://openmeter.cloud")
"OPENMETER_API_ENDPOINT", default_value="https://openmeter.cloud"
)
if _url.endswith("/"): if _url.endswith("/"):
_url += "api/v1/events" _url += "api/v1/events"
else: else:
_url += "/api/v1/events" _url += "/api/v1/events"
api_key = litellm.get_secret("OPENMETER_API_KEY") api_key = os.getenv("OPENMETER_API_KEY")
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj) _data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
_headers = { _headers = {
@ -117,7 +125,6 @@ class OpenMeterLogger(CustomLogger):
response.raise_for_status() response.raise_for_status()
except Exception as e: except Exception as e:
print(f"\nAn Exception Occurred - {str(e)}")
if hasattr(response, "text"): if hasattr(response, "text"):
print(f"\nError Message: {response.text}") litellm.print_verbose(f"\nError Message: {response.text}")
raise e raise e
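The logger above now reads its configuration from plain environment variables rather than `litellm.get_secret`; a minimal sketch of the expected setup and of the CloudEvent it POSTs to `/api/v1/events` (values illustrative):

```
import os

# Assumed environment, mirroring the os.getenv calls above
os.environ["OPENMETER_API_KEY"] = "om_..."                        # required
os.environ["OPENMETER_API_ENDPOINT"] = "https://openmeter.cloud"  # optional, this is the default
os.environ["OPENMETER_EVENT_TYPE"] = "litellm_tokens"             # optional, this is the default

# Shape of the CloudEvent built by _common_logic (illustrative values)
example_event = {
    "specversion": "1.0",
    "type": "litellm_tokens",
    "id": "call-abc123",
    "time": "2024-05-11T16:30:00+00:00",
    "subject": "end-user-1",  # the 'user' param passed on the request
    "source": "litellm-proxy",
    "data": {"model": "gpt-3.5-turbo", "cost": 0.00042, "total_tokens": 105},  # plus the other usage counters
}
```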

View file

@ -3,7 +3,7 @@
# On success, log events to Prometheus # On success, log events to Prometheus
import dotenv, os import dotenv, os
import requests import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
@ -19,7 +19,6 @@ class PrometheusLogger:
**kwargs, **kwargs,
): ):
try: try:
print(f"in init prometheus metrics")
from prometheus_client import Counter from prometheus_client import Counter
self.litellm_llm_api_failed_requests_metric = Counter( self.litellm_llm_api_failed_requests_metric = Counter(

View file

@ -4,7 +4,7 @@
import dotenv, os import dotenv, os
import requests import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
@ -183,7 +183,6 @@ class PrometheusServicesLogger:
) )
async def async_service_failure_hook(self, payload: ServiceLoggerPayload): async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
print(f"received error payload: {payload.error}")
if self.mock_testing: if self.mock_testing:
self.mock_testing_failure_calls += 1 self.mock_testing_failure_calls += 1

View file

@ -1,12 +1,13 @@
#### What this does #### #### What this does ####
# On success, logs events to Promptlayer # On success, logs events to Promptlayer
import dotenv, os import dotenv, os
import requests import requests # type: ignore
from pydantic import BaseModel from pydantic import BaseModel
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback
class PromptLayerLogger: class PromptLayerLogger:
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
@ -32,7 +33,11 @@ class PromptLayerLogger:
tags = kwargs["litellm_params"]["metadata"]["pl_tags"] tags = kwargs["litellm_params"]["metadata"]["pl_tags"]
# Remove "pl_tags" from metadata # Remove "pl_tags" from metadata
metadata = {k:v for k, v in kwargs["litellm_params"]["metadata"].items() if k != "pl_tags"} metadata = {
k: v
for k, v in kwargs["litellm_params"]["metadata"].items()
if k != "pl_tags"
}
print_verbose( print_verbose(
f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}" f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}"

View file

@ -2,7 +2,6 @@
# On success + failure, log events to Supabase # On success + failure, log events to Supabase
import dotenv, os import dotenv, os
import requests
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback

View file

@ -1,25 +1,82 @@
#### What this does #### #### What this does ####
# Class for sending Slack Alerts # # Class for sending Slack Alerts #
import dotenv, os import dotenv, os
from litellm.proxy._types import UserAPIKeyAuth
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import copy
import traceback
from litellm._logging import verbose_logger, verbose_proxy_logger from litellm._logging import verbose_logger, verbose_proxy_logger
import litellm import litellm, threading
from typing import List, Literal, Any, Union, Optional, Dict from typing import List, Literal, Any, Union, Optional, Dict
from litellm.caching import DualCache from litellm.caching import DualCache
import asyncio import asyncio
import aiohttp import aiohttp
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
import datetime import datetime
from pydantic import BaseModel
from enum import Enum
from datetime import datetime as dt, timedelta
from litellm.integrations.custom_logger import CustomLogger
import random
class SlackAlerting: class LiteLLMBase(BaseModel):
"""
Implements default functions, all pydantic objects should have.
"""
def json(self, **kwargs):
try:
return self.model_dump() # noqa
except:
# if using pydantic v1
return self.dict()
class SlackAlertingArgs(LiteLLMBase):
daily_report_frequency: int = 12 * 60 * 60 # 12 hours
report_check_interval: int = 5 * 60 # 5 minutes
class DeploymentMetrics(LiteLLMBase):
"""
Metrics per deployment, stored in cache
Used for daily reporting
"""
id: str
"""id of deployment in router model list"""
failed_request: bool
"""did it fail the request?"""
latency_per_output_token: Optional[float]
"""latency/output token of deployment"""
updated_at: dt
"""Current time of deployment being updated"""
class SlackAlertingCacheKeys(Enum):
"""
Enum for deployment daily metrics keys - {deployment_id}:{enum}
"""
failed_requests_key = "failed_requests_daily_metrics"
latency_key = "latency_daily_metrics"
report_sent_key = "daily_metrics_report_sent"
class SlackAlerting(CustomLogger):
"""
Class for sending Slack Alerts
"""
# Class variables or attributes # Class variables or attributes
def __init__( def __init__(
self, self,
alerting_threshold: float = 300, internal_usage_cache: Optional[DualCache] = None,
alerting_threshold: float = 300, # threshold for slow / hanging llm responses (in seconds)
alerting: Optional[List] = [], alerting: Optional[List] = [],
alert_types: Optional[ alert_types: Optional[
List[ List[
@ -29,6 +86,7 @@ class SlackAlerting:
"llm_requests_hanging", "llm_requests_hanging",
"budget_alerts", "budget_alerts",
"db_exceptions", "db_exceptions",
"daily_reports",
] ]
] ]
] = [ ] = [
@ -37,31 +95,23 @@ class SlackAlerting:
"llm_requests_hanging", "llm_requests_hanging",
"budget_alerts", "budget_alerts",
"db_exceptions", "db_exceptions",
"daily_reports",
], ],
alert_to_webhook_url: Optional[ alert_to_webhook_url: Optional[
Dict Dict
] = None, # if user wants to separate alerts to diff channels ] = None, # if user wants to separate alerts to diff channels
alerting_args={},
default_webhook_url: Optional[str] = None,
): ):
self.alerting_threshold = alerting_threshold self.alerting_threshold = alerting_threshold
self.alerting = alerting self.alerting = alerting
self.alert_types = alert_types self.alert_types = alert_types
self.internal_usage_cache = DualCache() self.internal_usage_cache = internal_usage_cache or DualCache()
self.async_http_handler = AsyncHTTPHandler() self.async_http_handler = AsyncHTTPHandler()
self.alert_to_webhook_url = alert_to_webhook_url self.alert_to_webhook_url = alert_to_webhook_url
self.langfuse_logger = None self.is_running = False
self.alerting_args = SlackAlertingArgs(**alerting_args)
try: self.default_webhook_url = default_webhook_url
from litellm.integrations.langfuse import LangFuseLogger
self.langfuse_logger = LangFuseLogger(
os.getenv("LANGFUSE_PUBLIC_KEY"),
os.getenv("LANGFUSE_SECRET_KEY"),
flush_interval=1,
)
except:
pass
pass
def update_values( def update_values(
self, self,
@ -69,6 +119,7 @@ class SlackAlerting:
alerting_threshold: Optional[float] = None, alerting_threshold: Optional[float] = None,
alert_types: Optional[List] = None, alert_types: Optional[List] = None,
alert_to_webhook_url: Optional[Dict] = None, alert_to_webhook_url: Optional[Dict] = None,
alerting_args: Optional[Dict] = None,
): ):
if alerting is not None: if alerting is not None:
self.alerting = alerting self.alerting = alerting
@ -76,7 +127,8 @@ class SlackAlerting:
self.alerting_threshold = alerting_threshold self.alerting_threshold = alerting_threshold
if alert_types is not None: if alert_types is not None:
self.alert_types = alert_types self.alert_types = alert_types
if alerting_args is not None:
self.alerting_args = SlackAlertingArgs(**alerting_args)
if alert_to_webhook_url is not None: if alert_to_webhook_url is not None:
# update the dict # update the dict
if self.alert_to_webhook_url is None: if self.alert_to_webhook_url is None:
@ -103,72 +155,23 @@ class SlackAlerting:
def _add_langfuse_trace_id_to_alert( def _add_langfuse_trace_id_to_alert(
self, self,
request_info: str,
request_data: Optional[dict] = None, request_data: Optional[dict] = None,
kwargs: Optional[dict] = None, ) -> Optional[str]:
type: Literal["hanging_request", "slow_response"] = "hanging_request", """
start_time: Optional[datetime.datetime] = None, Returns langfuse trace url
end_time: Optional[datetime.datetime] = None, """
# do nothing for now
if (
request_data is not None
and request_data.get("metadata", {}).get("trace_id", None) is not None
): ):
import uuid trace_id = request_data["metadata"]["trace_id"]
if litellm.utils.langFuseLogger is not None:
base_url = litellm.utils.langFuseLogger.Langfuse.base_url
return f"{base_url}/trace/{trace_id}"
return None
# For now: do nothing as we're debugging why this is not working as expected def _response_taking_too_long_callback_helper(
if request_data is not None:
trace_id = request_data.get("metadata", {}).get(
"trace_id", None
) # get langfuse trace id
if trace_id is None:
trace_id = "litellm-alert-trace-" + str(uuid.uuid4())
request_data["metadata"]["trace_id"] = trace_id
elif kwargs is not None:
_litellm_params = kwargs.get("litellm_params", {})
trace_id = _litellm_params.get("metadata", {}).get(
"trace_id", None
) # get langfuse trace id
if trace_id is None:
trace_id = "litellm-alert-trace-" + str(uuid.uuid4())
_litellm_params["metadata"]["trace_id"] = trace_id
# Log hanging request as an error on langfuse
if type == "hanging_request":
if self.langfuse_logger is not None:
_logging_kwargs = copy.deepcopy(request_data)
if _logging_kwargs is None:
_logging_kwargs = {}
_logging_kwargs["litellm_params"] = {}
request_data = request_data or {}
_logging_kwargs["litellm_params"]["metadata"] = request_data.get(
"metadata", {}
)
# log to langfuse in a separate thread
import threading
threading.Thread(
target=self.langfuse_logger.log_event,
args=(
_logging_kwargs,
None,
start_time,
end_time,
None,
print,
"ERROR",
"Requests is hanging",
),
).start()
_langfuse_host = os.environ.get("LANGFUSE_HOST", "https://cloud.langfuse.com")
_langfuse_project_id = os.environ.get("LANGFUSE_PROJECT_ID")
# langfuse urls look like: https://us.cloud.langfuse.com/project/************/traces/litellm-alert-trace-ididi9dk-09292-************
_langfuse_url = (
f"{_langfuse_host}/project/{_langfuse_project_id}/traces/{trace_id}"
)
request_info += f"\n🪢 Langfuse Trace: {_langfuse_url}"
return request_info
def _response_taking_too_long_callback(
self, self,
kwargs, # kwargs to completion kwargs, # kwargs to completion
start_time, start_time,
@ -233,7 +236,7 @@ class SlackAlerting:
return return
time_difference_float, model, api_base, messages = ( time_difference_float, model, api_base, messages = (
self._response_taking_too_long_callback( self._response_taking_too_long_callback_helper(
kwargs=kwargs, kwargs=kwargs,
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
@ -242,10 +245,6 @@ class SlackAlerting:
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`" request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`" slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
if time_difference_float > self.alerting_threshold: if time_difference_float > self.alerting_threshold:
if "langfuse" in litellm.success_callback:
request_info = self._add_langfuse_trace_id_to_alert(
request_info=request_info, kwargs=kwargs, type="slow_response"
)
# add deployment latencies to alert # add deployment latencies to alert
if ( if (
kwargs is not None kwargs is not None
@ -253,6 +252,9 @@ class SlackAlerting:
and "metadata" in kwargs["litellm_params"] and "metadata" in kwargs["litellm_params"]
): ):
_metadata = kwargs["litellm_params"]["metadata"] _metadata = kwargs["litellm_params"]["metadata"]
request_info = litellm.utils._add_key_name_and_team_to_alert(
request_info=request_info, metadata=_metadata
)
_deployment_latency_map = self._get_deployment_latencies_to_alert( _deployment_latency_map = self._get_deployment_latencies_to_alert(
metadata=_metadata metadata=_metadata
@ -267,8 +269,178 @@ class SlackAlerting:
alert_type="llm_too_slow", alert_type="llm_too_slow",
) )
async def log_failure_event(self, original_exception: Exception): async def async_update_daily_reports(
pass self, deployment_metrics: DeploymentMetrics
) -> int:
"""
Store the perf by deployment in cache
- Number of failed requests per deployment
- Latency / output tokens per deployment
'deployment_id:daily_metrics:failed_requests'
'deployment_id:daily_metrics:latency_per_output_token'
Returns
int - count of metrics set (1 - if just latency, 2 - if failed + latency)
"""
return_val = 0
try:
## FAILED REQUESTS ##
if deployment_metrics.failed_request:
await self.internal_usage_cache.async_increment_cache(
key="{}:{}".format(
deployment_metrics.id,
SlackAlertingCacheKeys.failed_requests_key.value,
),
value=1,
)
return_val += 1
## LATENCY ##
if deployment_metrics.latency_per_output_token is not None:
await self.internal_usage_cache.async_increment_cache(
key="{}:{}".format(
deployment_metrics.id, SlackAlertingCacheKeys.latency_key.value
),
value=deployment_metrics.latency_per_output_token,
)
return_val += 1
return return_val
except Exception as e:
return 0
async def send_daily_reports(self, router) -> bool:
"""
Send a daily report on:
- Top 5 deployments with most failed requests
- Top 5 slowest deployments (normalized by latency/output tokens)
Get the value from redis cache (if available) or in-memory and send it
Cleanup:
- reset values in cache -> prevent memory leak
Returns:
True -> if successfully sent
False -> if not sent
"""
ids = router.get_model_ids()
# get keys
failed_request_keys = [
"{}:{}".format(id, SlackAlertingCacheKeys.failed_requests_key.value)
for id in ids
]
latency_keys = [
"{}:{}".format(id, SlackAlertingCacheKeys.latency_key.value) for id in ids
]
combined_metrics_keys = failed_request_keys + latency_keys # reduce cache calls
combined_metrics_values = await self.internal_usage_cache.async_batch_get_cache(
keys=combined_metrics_keys
) # [1, 2, None, ..]
all_none = True
for val in combined_metrics_values:
if val is not None:
all_none = False
if all_none:
return False
failed_request_values = combined_metrics_values[
: len(failed_request_keys)
] # [1, 2, None, ..]
latency_values = combined_metrics_values[len(failed_request_keys) :]
# find top 5 failed
## Replace None values with a placeholder value (0 in this case)
placeholder_value = 0
replaced_failed_values = [
value if value is not None else placeholder_value
for value in failed_request_values
]
## Get the indices of top 5 keys with the highest numerical values (ignoring None values)
top_5_failed = sorted(
range(len(replaced_failed_values)),
key=lambda i: replaced_failed_values[i],
reverse=True,
)[:5]
# find top 5 slowest
# Replace None values with a placeholder value (0 in this case)
placeholder_value = 0
replaced_slowest_values = [
value if value is not None else placeholder_value
for value in latency_values
]
# Get the indices of top 5 values with the highest numerical values (ignoring None values)
top_5_slowest = sorted(
range(len(replaced_slowest_values)),
key=lambda i: replaced_slowest_values[i],
reverse=True,
)[:5]
# format alert -> return the litellm model name + api base
message = f"\n\nHere are today's key metrics 📈: \n\n"
message += "\n\n*❗️ Top 5 Deployments with Most Failed Requests:*\n\n"
for i in range(len(top_5_failed)):
key = failed_request_keys[top_5_failed[i]].split(":")[0]
_deployment = router.get_model_info(key)
if isinstance(_deployment, dict):
deployment_name = _deployment["litellm_params"].get("model", "")
else:
return False
api_base = litellm.get_api_base(
model=deployment_name,
optional_params=(
_deployment["litellm_params"] if _deployment is not None else {}
),
)
if api_base is None:
api_base = ""
value = replaced_failed_values[top_5_failed[i]]
message += f"\t{i+1}. Deployment: `{deployment_name}`, Failed Requests: `{value}`, API Base: `{api_base}`\n"
message += "\n\n*😅 Top 5 Slowest Deployments:*\n\n"
for i in range(len(top_5_slowest)):
key = latency_keys[top_5_slowest[i]].split(":")[0]
_deployment = router.get_model_info(key)
if _deployment is not None:
deployment_name = _deployment["litellm_params"].get("model", "")
else:
deployment_name = ""
api_base = litellm.get_api_base(
model=deployment_name,
optional_params=(
_deployment["litellm_params"] if _deployment is not None else {}
),
)
value = round(replaced_slowest_values[top_5_slowest[i]], 3)
message += f"\t{i+1}. Deployment: `{deployment_name}`, Latency per output token: `{value}s/token`, API Base: `{api_base}`\n\n"
# cache cleanup -> reset values to 0
latency_cache_keys = [(key, 0) for key in latency_keys]
failed_request_cache_keys = [(key, 0) for key in failed_request_keys]
combined_metrics_cache_keys = latency_cache_keys + failed_request_cache_keys
await self.internal_usage_cache.async_batch_set_cache(
cache_list=combined_metrics_cache_keys
)
# send alert
await self.send_alert(message=message, level="Low", alert_type="daily_reports")
return True
async def response_taking_too_long( async def response_taking_too_long(
self, self,
@ -326,6 +498,11 @@ class SlackAlerting:
# in that case we fallback to the api base set in the request metadata # in that case we fallback to the api base set in the request metadata
_metadata = request_data["metadata"] _metadata = request_data["metadata"]
_api_base = _metadata.get("api_base", "") _api_base = _metadata.get("api_base", "")
request_info = litellm.utils._add_key_name_and_team_to_alert(
request_info=request_info, metadata=_metadata
)
if _api_base is None: if _api_base is None:
_api_base = "" _api_base = ""
request_info += f"\nAPI Base: `{_api_base}`" request_info += f"\nAPI Base: `{_api_base}`"
@ -335,14 +512,13 @@ class SlackAlerting:
) )
if "langfuse" in litellm.success_callback: if "langfuse" in litellm.success_callback:
request_info = self._add_langfuse_trace_id_to_alert( langfuse_url = self._add_langfuse_trace_id_to_alert(
request_info=request_info,
request_data=request_data, request_data=request_data,
type="hanging_request",
start_time=start_time,
end_time=end_time,
) )
if langfuse_url is not None:
request_info += "\n🪢 Langfuse Trace: {}".format(langfuse_url)
# add deployment latencies to alert # add deployment latencies to alert
_deployment_latency_map = self._get_deployment_latencies_to_alert( _deployment_latency_map = self._get_deployment_latencies_to_alert(
metadata=request_data.get("metadata", {}) metadata=request_data.get("metadata", {})
@ -475,6 +651,53 @@ class SlackAlerting:
return return
async def model_added_alert(self, model_name: str, litellm_model_name: str):
model_info = litellm.model_cost.get(litellm_model_name, {})
model_info_str = ""
for k, v in model_info.items():
if k == "input_cost_per_token" or k == "output_cost_per_token":
# when converting to string it should not be 1.63e-06
v = "{:.8f}".format(v)
model_info_str += f"{k}: {v}\n"
message = f"""
*🚅 New Model Added*
Model Name: `{model_name}`
Usage OpenAI Python SDK:
```
import openai
client = openai.OpenAI(
api_key="your_api_key",
base_url="{os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000")}"
)
response = client.chat.completions.create(
model="{model_name}", # model to send to the proxy
messages = [
{{
"role": "user",
"content": "this is a test request, write a short poem"
}}
]
)
```
Model Info:
```
{model_info_str}
```
"""
await self.send_alert(
message=message, level="Low", alert_type="new_model_added"
)
pass
async def model_removed_alert(self, model_name: str):
pass
async def send_alert( async def send_alert(
self, self,
message: str, message: str,
@ -485,7 +708,11 @@ class SlackAlerting:
"llm_requests_hanging", "llm_requests_hanging",
"budget_alerts", "budget_alerts",
"db_exceptions", "db_exceptions",
"daily_reports",
"new_model_added",
"cooldown_deployment",
], ],
**kwargs,
): ):
""" """
Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298 Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
@ -510,9 +737,16 @@ class SlackAlerting:
# Get the current timestamp # Get the current timestamp
current_time = datetime.now().strftime("%H:%M:%S") current_time = datetime.now().strftime("%H:%M:%S")
_proxy_base_url = os.getenv("PROXY_BASE_URL", None) _proxy_base_url = os.getenv("PROXY_BASE_URL", None)
if alert_type == "daily_reports" or alert_type == "new_model_added":
formatted_message = message
else:
formatted_message = ( formatted_message = (
f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}" f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
) )
if kwargs:
for key, value in kwargs.items():
formatted_message += f"\n\n{key}: `{value}`\n\n"
if _proxy_base_url is not None: if _proxy_base_url is not None:
formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`" formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
@ -522,6 +756,8 @@ class SlackAlerting:
and alert_type in self.alert_to_webhook_url and alert_type in self.alert_to_webhook_url
): ):
slack_webhook_url = self.alert_to_webhook_url[alert_type] slack_webhook_url = self.alert_to_webhook_url[alert_type]
elif self.default_webhook_url is not None:
slack_webhook_url = self.default_webhook_url
else: else:
slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None) slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
@ -539,3 +775,113 @@ class SlackAlerting:
pass pass
else: else:
print("Error sending slack alert. Error=", response.text) # noqa print("Error sending slack alert. Error=", response.text) # noqa
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""Log deployment latency"""
if "daily_reports" in self.alert_types:
model_id = (
kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
)
response_s: timedelta = end_time - start_time
final_value = response_s
total_tokens = 0
if isinstance(response_obj, litellm.ModelResponse):
completion_tokens = response_obj.usage.completion_tokens
final_value = float(response_s.total_seconds() / completion_tokens)
await self.async_update_daily_reports(
DeploymentMetrics(
id=model_id,
failed_request=False,
latency_per_output_token=final_value,
updated_at=litellm.utils.get_utc_datetime(),
)
)
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
"""Log failure + deployment latency"""
if "daily_reports" in self.alert_types:
model_id = (
kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
)
await self.async_update_daily_reports(
DeploymentMetrics(
id=model_id,
failed_request=True,
latency_per_output_token=None,
updated_at=litellm.utils.get_utc_datetime(),
)
)
if "llm_exceptions" in self.alert_types:
original_exception = kwargs.get("exception", None)
await self.send_alert(
message="LLM API Failure - " + str(original_exception),
level="High",
alert_type="llm_exceptions",
)
async def _run_scheduler_helper(self, llm_router) -> bool:
"""
Returns:
- True -> report sent
- False -> report not sent
"""
report_sent_bool = False
report_sent = await self.internal_usage_cache.async_get_cache(
key=SlackAlertingCacheKeys.report_sent_key.value
) # None | datetime
current_time = litellm.utils.get_utc_datetime()
if report_sent is None:
_current_time = current_time.isoformat()
await self.internal_usage_cache.async_set_cache(
key=SlackAlertingCacheKeys.report_sent_key.value,
value=_current_time,
)
else:
# check if current time - interval >= time last sent
delta = current_time - timedelta(
seconds=self.alerting_args.daily_report_frequency
)
if isinstance(report_sent, str):
report_sent = dt.fromisoformat(report_sent)
if delta >= report_sent:
# Sneak in the reporting logic here
await self.send_daily_reports(router=llm_router)
# Also, don't forget to update the report_sent time after sending the report!
_current_time = current_time.isoformat()
await self.internal_usage_cache.async_set_cache(
key=SlackAlertingCacheKeys.report_sent_key.value,
value=_current_time,
)
report_sent_bool = True
return report_sent_bool
async def _run_scheduled_daily_report(self, llm_router: Optional[Any] = None):
"""
If 'daily_reports' enabled
Ping redis cache every 5 minutes to check if we should send the report
If yes -> call send_daily_report()
"""
if llm_router is None or self.alert_types is None:
return
if "daily_reports" in self.alert_types:
while True:
await self._run_scheduler_helper(llm_router=llm_router)
interval = random.randint(
self.alerting_args.report_check_interval - 3,
self.alerting_args.report_check_interval + 3,
) # shuffle to prevent collisions
await asyncio.sleep(interval)
return
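A hedged sketch of constructing the reworked SlackAlerting above with the daily-report settings it now accepts; the class and field names come from this diff, while the import path and values are assumptions:

```
from litellm.caching import DualCache
from litellm.integrations.slack_alerting import SlackAlerting  # module path assumed from this file

slack_alerting = SlackAlerting(
    internal_usage_cache=DualCache(),  # shared cache; defaults to a fresh DualCache() if omitted
    alerting=["slack"],
    alert_types=["llm_exceptions", "llm_too_slow", "daily_reports"],
    alerting_args={
        "daily_report_frequency": 12 * 60 * 60,  # send the report every 12 hours (the default)
        "report_check_interval": 5 * 60,         # poll the cache roughly every 5 minutes (the default)
    },
    default_webhook_url="https://hooks.slack.com/services/T000/B000/XXXX",  # placeholder webhook
)

# The report loop is expected to run as a background task, e.g. (router is a litellm Router):
# asyncio.create_task(slack_alerting._run_scheduled_daily_report(llm_router=router))
```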

View file

@ -2,7 +2,7 @@
# On success + failure, log events to Supabase # On success + failure, log events to Supabase
import dotenv, os import dotenv, os
import requests import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv dotenv.load_dotenv() # Loading env variables using dotenv
import traceback import traceback

View file

@ -1,8 +1,8 @@
import os, types, traceback import os, types, traceback
import json import json
from enum import Enum from enum import Enum
import requests import requests # type: ignore
import time, httpx import time, httpx # type: ignore
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse, Choices, Message from litellm.utils import ModelResponse, Choices, Message
import litellm import litellm

View file

@ -1,12 +1,12 @@
import os, types import os, types
import json import json
from enum import Enum from enum import Enum
import requests import requests # type: ignore
import time import time
from typing import Callable, Optional from typing import Callable, Optional
import litellm import litellm
from litellm.utils import ModelResponse, Choices, Message, Usage from litellm.utils import ModelResponse, Choices, Message, Usage
import httpx import httpx # type: ignore
class AlephAlphaError(Exception): class AlephAlphaError(Exception):

View file

@ -1,7 +1,7 @@
import os, types import os, types
import json import json
from enum import Enum from enum import Enum
import requests, copy import requests, copy # type: ignore
import time import time
from typing import Callable, Optional, List from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
@ -9,7 +9,7 @@ import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM from .base import BaseLLM
import httpx import httpx # type: ignore
class AnthropicConstants(Enum): class AnthropicConstants(Enum):
@ -84,6 +84,51 @@ class AnthropicConfig:
and v is not None and v is not None
} }
def get_supported_openai_params(self):
return [
"stream",
"stop",
"temperature",
"top_p",
"max_tokens",
"tools",
"tool_choice",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "stream" and value == True:
optional_params["stream"] = value
if param == "stop":
if isinstance(value, str):
if (
value == "\n"
) and litellm.drop_params == True: # anthropic doesn't allow whitespace characters as stop-sequences
continue
value = [value]
elif isinstance(value, list):
new_v = []
for v in value:
if (
v == "\n"
) and litellm.drop_params == True: # anthropic doesn't allow whitespace characters as stop-sequences
continue
new_v.append(v)
if len(new_v) > 0:
value = new_v
else:
continue
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
return optional_params
# makes headers for API call # makes headers for API call
def validate_environment(api_key, user_headers): def validate_environment(api_key, user_headers):
@ -139,11 +184,6 @@ class AnthropicChatCompletion(BaseLLM):
message=str(completion_response["error"]), message=str(completion_response["error"]),
status_code=response.status_code, status_code=response.status_code,
) )
elif len(completion_response["content"]) == 0:
raise AnthropicError(
message="No content in response",
status_code=response.status_code,
)
else: else:
text_content = "" text_content = ""
tool_calls = [] tool_calls = []
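A hedged sketch of how the new AnthropicConfig parameter mapping above behaves; the class and method come from this diff, while the import path and inputs are illustrative:

```
import litellm
from litellm.llms.anthropic import AnthropicConfig  # module path assumed from this file

litellm.drop_params = True  # so a bare "\n" stop value is silently dropped, as handled above

optional_params = AnthropicConfig().map_openai_params(
    non_default_params={
        "max_tokens": 256,
        "temperature": 0.2,
        "stop": ["\n", "Human:"],  # "\n" is filtered out, the rest become stop_sequences
        "stream": True,
    },
    optional_params={},
)
# -> {"max_tokens": 256, "temperature": 0.2, "stop_sequences": ["Human:"], "stream": True}
```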

View file

@ -1,4 +1,4 @@
from typing import Optional, Union, Any from typing import Optional, Union, Any, Literal
import types, requests import types, requests
from .base import BaseLLM from .base import BaseLLM
from litellm.utils import ( from litellm.utils import (
@ -8,14 +8,16 @@ from litellm.utils import (
CustomStreamWrapper, CustomStreamWrapper,
convert_to_model_response_object, convert_to_model_response_object,
TranscriptionResponse, TranscriptionResponse,
get_secret,
) )
from typing import Callable, Optional, BinaryIO from typing import Callable, Optional, BinaryIO
from litellm import OpenAIConfig from litellm import OpenAIConfig
import litellm, json import litellm, json
import httpx import httpx # type: ignore
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI from openai import AzureOpenAI, AsyncAzureOpenAI
import uuid import uuid
import os
class AzureOpenAIError(Exception): class AzureOpenAIError(Exception):
@ -126,6 +128,51 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
return azure_client_params return azure_client_params
def get_azure_ad_token_from_oidc(azure_ad_token: str):
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
azure_tenant = os.getenv("AZURE_TENANT_ID", None)
if azure_client_id is None or azure_tenant is None:
raise AzureOpenAIError(
status_code=422,
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
)
oidc_token = get_secret(azure_ad_token)
if oidc_token is None:
raise AzureOpenAIError(
status_code=401,
message="OIDC token could not be retrieved from secret manager.",
)
req_token = httpx.post(
f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token",
data={
"client_id": azure_client_id,
"grant_type": "client_credentials",
"scope": "https://cognitiveservices.azure.com/.default",
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": oidc_token,
},
)
if req_token.status_code != 200:
raise AzureOpenAIError(
status_code=req_token.status_code,
message=req_token.text,
)
possible_azure_ad_token = req_token.json().get("access_token", None)
if possible_azure_ad_token is None:
raise AzureOpenAIError(
status_code=422, message="Azure AD Token not returned"
)
return possible_azure_ad_token
class AzureChatCompletion(BaseLLM): class AzureChatCompletion(BaseLLM):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
@ -137,6 +184,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None: if api_key is not None:
headers["api-key"] = api_key headers["api-key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
headers["Authorization"] = f"Bearer {azure_ad_token}" headers["Authorization"] = f"Bearer {azure_ad_token}"
return headers return headers
@ -151,7 +200,7 @@ class AzureChatCompletion(BaseLLM):
api_type: str, api_type: str,
azure_ad_token: str, azure_ad_token: str,
print_verbose: Callable, print_verbose: Callable,
timeout, timeout: Union[float, httpx.Timeout],
logging_obj, logging_obj,
optional_params, optional_params,
litellm_params, litellm_params,
@ -189,6 +238,9 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if acompletion is True: if acompletion is True:
@ -276,6 +328,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if client is None: if client is None:
azure_client = AzureOpenAI(**azure_client_params) azure_client = AzureOpenAI(**azure_client_params)
@ -351,6 +405,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
# setting Azure client # setting Azure client
@ -422,6 +478,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if client is None: if client is None:
azure_client = AzureOpenAI(**azure_client_params) azure_client = AzureOpenAI(**azure_client_params)
@ -478,6 +536,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if client is None: if client is None:
azure_client = AsyncAzureOpenAI(**azure_client_params) azure_client = AsyncAzureOpenAI(**azure_client_params)
@ -599,6 +659,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
## LOGGING ## LOGGING
@ -755,6 +817,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if aimg_generation == True: if aimg_generation == True:
@ -833,6 +897,8 @@ class AzureChatCompletion(BaseLLM):
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if max_retries is not None: if max_retries is not None:
@ -952,6 +1018,81 @@ class AzureChatCompletion(BaseLLM):
) )
raise e raise e
def get_headers(
self,
model: Optional[str],
api_key: str,
api_base: str,
api_version: str,
timeout: float,
mode: str,
messages: Optional[list] = None,
input: Optional[list] = None,
prompt: Optional[str] = None,
) -> dict:
client_session = litellm.client_session or httpx.Client(
transport=CustomHTTPTransport(), # handle dall-e-2 calls
)
if "gateway.ai.cloudflare.com" in api_base:
## build base url - assume api base includes resource name
if not api_base.endswith("/"):
api_base += "/"
api_base += f"{model}"
client = AzureOpenAI(
base_url=api_base,
api_version=api_version,
api_key=api_key,
timeout=timeout,
http_client=client_session,
)
model = None
# cloudflare ai gateway, needs model=None
else:
client = AzureOpenAI(
api_version=api_version,
azure_endpoint=api_base,
api_key=api_key,
timeout=timeout,
http_client=client_session,
)
# only run this check if it's not cloudflare ai gateway
if model is None and mode != "image_generation":
raise Exception("model is not set")
completion = None
if messages is None:
messages = [{"role": "user", "content": "Hey"}]
try:
completion = client.chat.completions.with_raw_response.create(
model=model, # type: ignore
messages=messages, # type: ignore
)
except Exception as e:
raise e
response = {}
if completion is None or not hasattr(completion, "headers"):
raise Exception("invalid completion response")
if (
completion.headers.get("x-ratelimit-remaining-requests", None) is not None
): # not provided for dall-e requests
response["x-ratelimit-remaining-requests"] = completion.headers[
"x-ratelimit-remaining-requests"
]
if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
response["x-ratelimit-remaining-tokens"] = completion.headers[
"x-ratelimit-remaining-tokens"
]
if completion.headers.get("x-ms-region", None) is not None:
response["x-ms-region"] = completion.headers["x-ms-region"]
return response
async def ahealth_check( async def ahealth_check(
self, self,
model: Optional[str], model: Optional[str],
@ -963,7 +1104,7 @@ class AzureChatCompletion(BaseLLM):
messages: Optional[list] = None, messages: Optional[list] = None,
input: Optional[list] = None, input: Optional[list] = None,
prompt: Optional[str] = None, prompt: Optional[str] = None,
): ) -> dict:
client_session = litellm.aclient_session or httpx.AsyncClient( client_session = litellm.aclient_session or httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(), # handle dall-e-2 calls transport=AsyncCustomHTTPTransport(), # handle dall-e-2 calls
) )
@ -1040,4 +1181,8 @@ class AzureChatCompletion(BaseLLM):
response["x-ratelimit-remaining-tokens"] = completion.headers[ response["x-ratelimit-remaining-tokens"] = completion.headers[
"x-ratelimit-remaining-tokens" "x-ratelimit-remaining-tokens"
] ]
if completion.headers.get("x-ms-region", None) is not None:
response["x-ms-region"] = completion.headers["x-ms-region"]
return response return response
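A hedged sketch of wiring up the new `oidc/` Azure AD token flow added above; the env var names come from get_azure_ad_token_from_oidc, everything else (deployment, endpoint, secret name) is a placeholder:

```
import os
import litellm

# Assumed environment for the client-credentials exchange performed above
os.environ["AZURE_CLIENT_ID"] = "<app-registration-client-id>"
os.environ["AZURE_TENANT_ID"] = "<tenant-id>"

# An azure_ad_token that starts with "oidc/" triggers get_azure_ad_token_from_oidc; the part
# after the prefix is resolved via get_secret(), so it can reference wherever the OIDC JWT lives.
response = litellm.completion(
    model="azure/my-deployment",                      # placeholder deployment
    messages=[{"role": "user", "content": "hello"}],
    api_base="https://my-resource.openai.azure.com",  # placeholder
    api_version="2024-02-01",                         # placeholder
    azure_ad_token="oidc/AZURE_FEDERATED_TOKEN",      # hypothetical secret reference
)
```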

View file

@ -1,5 +1,5 @@
from typing import Optional, Union, Any from typing import Optional, Union, Any
import types, requests import types, requests # type: ignore
from .base import BaseLLM from .base import BaseLLM
from litellm.utils import ( from litellm.utils import (
ModelResponse, ModelResponse,

View file

@ -1,7 +1,7 @@
import os import os
import json import json
from enum import Enum from enum import Enum
import requests import requests # type: ignore
import time import time
from typing import Callable from typing import Callable
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage

View file

@ -4,7 +4,13 @@ from enum import Enum
import time, uuid import time, uuid
from typing import Callable, Optional, Any, Union, List from typing import Callable, Optional, Any, Union, List
import litellm import litellm
from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse from litellm.utils import (
ModelResponse,
get_secret,
Usage,
ImageResponse,
map_finish_reason,
)
from .prompt_templates.factory import ( from .prompt_templates.factory import (
prompt_factory, prompt_factory,
custom_prompt, custom_prompt,
@ -157,6 +163,7 @@ class AmazonAnthropicClaude3Config:
"stop", "stop",
"temperature", "temperature",
"top_p", "top_p",
"extra_headers",
] ]
def map_openai_params(self, non_default_params: dict, optional_params: dict): def map_openai_params(self, non_default_params: dict, optional_params: dict):
@ -524,6 +531,17 @@ class AmazonStabilityConfig:
} }
def add_custom_header(headers):
"""Closure to capture the headers and add them."""
def callback(request, **kwargs):
"""Actual callback function that Boto3 will call."""
for header_name, header_value in headers.items():
request.headers.add_header(header_name, header_value)
return callback
def init_bedrock_client(
region_name=None,
aws_access_key_id: Optional[str] = None,
@@ -533,12 +551,13 @@ def init_bedrock_client(
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
aws_web_identity_token: Optional[str] = None,
extra_headers: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
):
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
standard_aws_region_name = get_secret("AWS_REGION", None)
## CHECK IS 'os.environ/' passed in
# Define the list of parameters to check
params_to_check = [
@@ -549,6 +568,7 @@ def init_bedrock_client(
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
]
# Iterate over parameters and update if needed
@@ -564,6 +584,7 @@ def init_bedrock_client(
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
) = params_to_check
### SET REGION NAME
@@ -592,10 +613,48 @@ def init_bedrock_client(
import boto3
if isinstance(timeout, float):
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
elif isinstance(timeout, httpx.Timeout):
config = boto3.session.Config(
connect_timeout=timeout.connect, read_timeout=timeout.read
)
else:
config = boto3.session.Config()
### CHECK STS ###
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client(
"sts"
)
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
client = boto3.client(
service_name="bedrock-runtime",
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=region_name,
endpoint_url=endpoint_url,
config=config,
)
elif aws_role_name is not None and aws_session_name is not None:
# use sts if role name passed in
sts_client = boto3.client(
"sts",
@@ -647,6 +706,10 @@ def init_bedrock_client(
endpoint_url=endpoint_url,
config=config,
)
if extra_headers:
client.meta.events.register(
"before-sign.bedrock-runtime.*", add_custom_header(extra_headers)
)
return client
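For the new OIDC branch, the underlying flow is plain STS: exchange the web-identity token for temporary credentials, then build the bedrock-runtime client from them. A hedged sketch with placeholder role ARN and token path:

import boto3

with open("/var/run/secrets/tokens/oidc-token") as f:  # placeholder token path
    oidc_token = f.read()

sts = boto3.client("sts")
creds = sts.assume_role_with_web_identity(
    RoleArn="arn:aws:iam::111122223333:role/bedrock-caller",  # placeholder role
    RoleSessionName="litellm-session",
    WebIdentityToken=oidc_token,
    DurationSeconds=3600,
)["Credentials"]

bedrock = boto3.client(
    "bedrock-runtime",
    aws_access_key_id=creds["AccessKeyId"],
    aws_secret_access_key=creds["SecretAccessKey"],
    aws_session_token=creds["SessionToken"],
    region_name="us-east-1",
)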
@ -710,6 +773,7 @@ def completion(
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
timeout=None, timeout=None,
extra_headers: Optional[dict] = None,
): ):
exception_mapping_worked = False exception_mapping_worked = False
_is_function_call = False _is_function_call = False
@ -725,6 +789,7 @@ def completion(
aws_bedrock_runtime_endpoint = optional_params.pop( aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None "aws_bedrock_runtime_endpoint", None
) )
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one # use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = optional_params.pop("aws_bedrock_client", None) client = optional_params.pop("aws_bedrock_client", None)
@ -739,6 +804,8 @@ def completion(
aws_role_name=aws_role_name, aws_role_name=aws_role_name,
aws_session_name=aws_session_name, aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name, aws_profile_name=aws_profile_name,
aws_web_identity_token=aws_web_identity_token,
extra_headers=extra_headers,
timeout=timeout, timeout=timeout,
) )
@@ -1043,7 +1110,9 @@ def completion(
logging_obj=logging_obj,
)
model_response["finish_reason"] = map_finish_reason(
response_body["stop_reason"]
)
_usage = litellm.Usage(
prompt_tokens=response_body["usage"]["input_tokens"],
completion_tokens=response_body["usage"]["output_tokens"],
@ -1194,7 +1263,7 @@ def _embedding_func_single(
"input_type", "search_document" "input_type", "search_document"
) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3 ) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
data = {"texts": [input], **inference_params} # type: ignore data = {"texts": [input], **inference_params} # type: ignore
body = json.dumps(data).encode("utf-8") body = json.dumps(data).encode("utf-8") # type: ignore
## LOGGING ## LOGGING
request_str = f""" request_str = f"""
response = client.invoke_model( response = client.invoke_model(
@ -1258,6 +1327,7 @@ def embedding(
aws_bedrock_runtime_endpoint = optional_params.pop( aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None "aws_bedrock_runtime_endpoint", None
) )
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one # use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = init_bedrock_client( client = init_bedrock_client(
@ -1265,6 +1335,7 @@ def embedding(
aws_secret_access_key=aws_secret_access_key, aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name, aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name, aws_role_name=aws_role_name,
aws_session_name=aws_session_name, aws_session_name=aws_session_name,
) )
@ -1347,6 +1418,7 @@ def image_generation(
aws_bedrock_runtime_endpoint = optional_params.pop( aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None "aws_bedrock_runtime_endpoint", None
) )
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one # use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = init_bedrock_client( client = init_bedrock_client(
@ -1354,6 +1426,7 @@ def image_generation(
aws_secret_access_key=aws_secret_access_key, aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name, aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name, aws_role_name=aws_role_name,
aws_session_name=aws_session_name, aws_session_name=aws_session_name,
timeout=timeout, timeout=timeout,
@ -1386,7 +1459,7 @@ def image_generation(
## LOGGING ## LOGGING
request_str = f""" request_str = f"""
response = client.invoke_model( response = client.invoke_model(
body={body}, body={body}, # type: ignore
modelId={modelId}, modelId={modelId},
accept="application/json", accept="application/json",
contentType="application/json", contentType="application/json",
@@ -1,11 +1,11 @@
import os, types
import json
from enum import Enum
import requests # type: ignore
import time
from typing import Callable, Optional
import litellm
import httpx # type: ignore
from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -1,12 +1,12 @@
import os, types
import json
from enum import Enum
import requests # type: ignore
import time, traceback
from typing import Callable, Optional
from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
import httpx # type: ignore
class CohereError(Exception):
@@ -1,12 +1,12 @@
import os, types
import json
from enum import Enum
import requests # type: ignore
import time, traceback
from typing import Callable, Optional
from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
import httpx # type: ignore
from .prompt_templates.factory import cohere_message_pt
@@ -6,10 +6,12 @@ import httpx, requests
from .base import BaseLLM
import time
import litellm
from typing import Callable, Dict, List, Any, Literal
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
from typing import Optional
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.types.completion import ChatCompletionMessageToolCallParam
import enum
class HuggingfaceError(Exception):
@@ -39,11 +41,29 @@ class HuggingfaceError(Exception):
) # Call the base class constructor with the parameters it needs
hf_task_list = [
"text-generation-inference",
"conversational",
"text-classification",
"text-generation",
]
hf_tasks = Literal[
"text-generation-inference",
"conversational",
"text-classification",
"text-generation",
]
class HuggingfaceConfig:
"""
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
"""
hf_task: Optional[hf_tasks] = (
None # litellm-specific param, used to know the api spec to use when calling huggingface api
)
best_of: Optional[int] = None
decoder_input_details: Optional[bool] = None
details: Optional[bool] = True # enables returning logprobs + best of
@@ -101,6 +121,51 @@ class HuggingfaceConfig:
and v is not None
}
def get_supported_openai_params(self):
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"stop",
"n",
"echo",
]
def map_openai_params(
self, non_default_params: dict, optional_params: dict
) -> dict:
for param, value in non_default_params.items():
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
if param == "temperature":
if value == 0.0 or value == 0:
# hugging face exception raised when temp==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
value = 0.01
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "n":
optional_params["best_of"] = value
optional_params["do_sample"] = (
True # Need to sample if you want best of for hf inference endpoints
)
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop"] = value
if param == "max_tokens":
# HF TGI raises the following exception when max_new_tokens==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
if value == 0:
value = 1
optional_params["max_new_tokens"] = value
if param == "echo":
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
optional_params["decoder_input_details"] = True
return optional_params
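Illustrative use of the mapping above, assuming HuggingfaceConfig as defined in this hunk; it shows how the TGI edge cases (temperature 0, max_tokens 0, n) are rewritten:

config = HuggingfaceConfig()
optional_params = config.map_openai_params(
    non_default_params={"temperature": 0, "max_tokens": 0, "n": 2, "stop": ["\n"]},
    optional_params={},
)
# Expected result:
# {"temperature": 0.01, "max_new_tokens": 1, "best_of": 2, "do_sample": True, "stop": ["\n"]}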
def output_parser(generated_text: str): def output_parser(generated_text: str):
""" """
@@ -162,16 +227,18 @@ def read_tgi_conv_models():
return set(), set()
def get_hf_task_for_model(model: str) -> hf_tasks:
# read text file, cast it to set
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
if model.split("/")[0] in hf_task_list:
return model.split("/")[0] # type: ignore
tgi_models, conversational_models = read_tgi_conv_models()
if model in tgi_models:
return "text-generation-inference"
elif model in conversational_models:
return "conversational"
elif "roneneldan/TinyStories" in model:
return "text-generation"
else:
return "text-generation-inference" # default to tgi
@ -202,7 +269,7 @@ class Huggingface(BaseLLM):
self, self,
completion_response, completion_response,
model_response, model_response,
task, task: hf_tasks,
optional_params, optional_params,
encoding, encoding,
input_text, input_text,
@ -270,6 +337,10 @@ class Huggingface(BaseLLM):
) )
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"].extend(choices_list) model_response["choices"].extend(choices_list)
elif task == "text-classification":
model_response["choices"][0]["message"]["content"] = json.dumps(
completion_response
)
else: else:
if len(completion_response[0]["generated_text"]) > 0: if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = output_parser( model_response["choices"][0]["message"]["content"] = output_parser(
@ -322,9 +393,9 @@ class Huggingface(BaseLLM):
encoding, encoding,
api_key, api_key,
logging_obj, logging_obj,
optional_params: dict,
custom_prompt_dict={}, custom_prompt_dict={},
acompletion: bool = False, acompletion: bool = False,
optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
): ):
@ -333,6 +404,12 @@ class Huggingface(BaseLLM):
try: try:
headers = self.validate_environment(api_key, headers) headers = self.validate_environment(api_key, headers)
task = get_hf_task_for_model(model) task = get_hf_task_for_model(model)
## VALIDATE API FORMAT
if task is None or not isinstance(task, str) or task not in hf_task_list:
raise Exception(
"Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
)
print_verbose(f"{model}, {task}") print_verbose(f"{model}, {task}")
completion_url = "" completion_url = ""
input_text = "" input_text = ""
@ -399,10 +476,11 @@ class Huggingface(BaseLLM):
data = { data = {
"inputs": prompt, "inputs": prompt,
"parameters": optional_params, "parameters": optional_params,
"stream": ( "stream": ( # type: ignore
True True
if "stream" in optional_params if "stream" in optional_params
and optional_params["stream"] == True and isinstance(optional_params["stream"], bool)
and optional_params["stream"] == True # type: ignore
else False else False
), ),
} }
@ -432,14 +510,15 @@ class Huggingface(BaseLLM):
inference_params.pop("return_full_text") inference_params.pop("return_full_text")
data = { data = {
"inputs": prompt, "inputs": prompt,
"parameters": inference_params, }
"stream": ( if task == "text-generation-inference":
data["parameters"] = inference_params
data["stream"] = ( # type: ignore
True True
if "stream" in optional_params if "stream" in optional_params
and optional_params["stream"] == True and optional_params["stream"] == True
else False else False
), )
}
input_text = prompt input_text = prompt
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
@ -530,10 +609,10 @@ class Huggingface(BaseLLM):
isinstance(completion_response, dict) isinstance(completion_response, dict)
and "error" in completion_response and "error" in completion_response
): ):
print_verbose(f"completion error: {completion_response['error']}") print_verbose(f"completion error: {completion_response['error']}") # type: ignore
print_verbose(f"response.status_code: {response.status_code}") print_verbose(f"response.status_code: {response.status_code}")
raise HuggingfaceError( raise HuggingfaceError(
message=completion_response["error"], message=completion_response["error"], # type: ignore
status_code=response.status_code, status_code=response.status_code,
) )
return self.convert_to_model_response_object( return self.convert_to_model_response_object(
@ -562,7 +641,7 @@ class Huggingface(BaseLLM):
data: dict, data: dict,
headers: dict, headers: dict,
model_response: ModelResponse, model_response: ModelResponse,
task: str, task: hf_tasks,
encoding: Any, encoding: Any,
input_text: str, input_text: str,
model: str, model: str,
@@ -1,7 +1,7 @@
import os, types
import json
from enum import Enum
import requests # type: ignore
import time, traceback
from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Choices, Message, Usage
@@ -1,7 +1,7 @@
import os, types
import json
from enum import Enum
import requests # type: ignore
import time
from typing import Callable, Optional
import litellm
@@ -1,9 +1,10 @@
from itertools import chain
import requests, types, time # type: ignore
import json, uuid
import traceback
from typing import Optional
import litellm
import httpx, aiohttp, asyncio # type: ignore
from .prompt_templates.factory import prompt_factory, custom_prompt
@ -212,25 +213,31 @@ def get_ollama_response(
## RESPONSE OBJECT ## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop" model_response["choices"][0]["finish_reason"] = "stop"
if optional_params.get("format", "") == "json": if data.get("format", "") == "json":
function_call = json.loads(response_json["response"]) function_call = json.loads(response_json["response"])
message = litellm.Message( message = litellm.Message(
content=None, content=None,
tool_calls=[ tool_calls=[
{ {
"id": f"call_{str(uuid.uuid4())}", "id": f"call_{str(uuid.uuid4())}",
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])}, "function": {
"name": function_call["name"],
"arguments": json.dumps(function_call["arguments"]),
},
"type": "function", "type": "function",
} }
], ],
) )
model_response["choices"][0]["message"] = message model_response["choices"][0]["message"] = message
model_response["choices"][0]["finish_reason"] = "tool_calls"
else: else:
model_response["choices"][0]["message"]["content"] = response_json["response"] model_response["choices"][0]["message"]["content"] = response_json["response"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model model_response["model"] = "ollama/" + model
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=()))) # type: ignore prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=()))) # type: ignore
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", ""))) completion_tokens = response_json.get(
"eval_count", len(response_json.get("message", dict()).get("content", ""))
)
model_response["usage"] = litellm.Usage( model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
@ -255,6 +262,35 @@ def ollama_completion_stream(url, data, logging_obj):
custom_llm_provider="ollama", custom_llm_provider="ollama",
logging_obj=logging_obj, logging_obj=logging_obj,
) )
# If format is JSON, this was a function call
# Gather all chunks and return the function call as one delta to simplify parsing
if data.get("format", "") == "json":
first_chunk = next(streamwrapper)
response_content = "".join(
chunk.choices[0].delta.content
for chunk in chain([first_chunk], streamwrapper)
if chunk.choices[0].delta.content
)
function_call = json.loads(response_content)
delta = litellm.utils.Delta(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"name": function_call["name"],
"arguments": json.dumps(function_call["arguments"]),
},
"type": "function",
}
],
)
model_response = first_chunk
model_response["choices"][0]["delta"] = delta
model_response["choices"][0]["finish_reason"] = "tool_calls"
yield model_response
else:
for transformed_chunk in streamwrapper: for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
except Exception as e: except Exception as e:
@ -278,6 +314,38 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
custom_llm_provider="ollama", custom_llm_provider="ollama",
logging_obj=logging_obj, logging_obj=logging_obj,
) )
# If format is JSON, this was a function call
# Gather all chunks and return the function call as one delta to simplify parsing
if data.get("format", "") == "json":
first_chunk = await anext(streamwrapper)
first_chunk_content = first_chunk.choices[0].delta.content or ""
response_content = first_chunk_content + "".join(
[
chunk.choices[0].delta.content
async for chunk in streamwrapper
if chunk.choices[0].delta.content
]
)
function_call = json.loads(response_content)
delta = litellm.utils.Delta(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"name": function_call["name"],
"arguments": json.dumps(function_call["arguments"]),
},
"type": "function",
}
],
)
model_response = first_chunk
model_response["choices"][0]["delta"] = delta
model_response["choices"][0]["finish_reason"] = "tool_calls"
yield model_response
else:
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
except Exception as e: except Exception as e:
@ -317,12 +385,16 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
tool_calls=[ tool_calls=[
{ {
"id": f"call_{str(uuid.uuid4())}", "id": f"call_{str(uuid.uuid4())}",
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])}, "function": {
"name": function_call["name"],
"arguments": json.dumps(function_call["arguments"]),
},
"type": "function", "type": "function",
} }
], ],
) )
model_response["choices"][0]["message"] = message model_response["choices"][0]["message"] = message
model_response["choices"][0]["finish_reason"] = "tool_calls"
else: else:
model_response["choices"][0]["message"]["content"] = response_json[ model_response["choices"][0]["message"]["content"] = response_json[
"response" "response"
@ -330,7 +402,10 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + data["model"] model_response["model"] = "ollama/" + data["model"]
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=()))) # type: ignore prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=()))) # type: ignore
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", ""))) completion_tokens = response_json.get(
"eval_count",
len(response_json.get("message", dict()).get("content", "")),
)
model_response["usage"] = litellm.Usage( model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
@ -417,3 +492,25 @@ async def ollama_aembeddings(
"total_tokens": total_input_tokens, "total_tokens": total_input_tokens,
} }
return model_response return model_response
def ollama_embeddings(
api_base: str,
model: str,
prompts: list,
optional_params=None,
logging_obj=None,
model_response=None,
encoding=None,
):
return asyncio.run(
ollama_aembeddings(
api_base,
model,
prompts,
optional_params,
logging_obj,
model_response,
encoding,
)
)
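The format=json streaming branches above all follow the same gather-then-parse pattern: drain the stream, join the content deltas, parse the result as a single function call, and emit one tool_calls delta. A standalone sketch of that pattern (illustrative names; works on any iterator of OpenAI-style chunks):

from itertools import chain
import json
import uuid

def collect_streamed_function_call(stream):
    # Consume the whole stream and join every non-empty content delta.
    first_chunk = next(stream)
    response_content = "".join(
        chunk.choices[0].delta.content
        for chunk in chain([first_chunk], stream)
        if chunk.choices[0].delta.content
    )
    function_call = json.loads(response_content)
    # Re-emit the gathered text as a single OpenAI-style tool call.
    return {
        "id": f"call_{uuid.uuid4()}",
        "type": "function",
        "function": {
            "name": function_call["name"],
            "arguments": json.dumps(function_call["arguments"]),
        },
    }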
@ -1,3 +1,4 @@
from itertools import chain
import requests, types, time import requests, types, time
import json, uuid import json, uuid
import traceback import traceback
@ -297,6 +298,7 @@ def get_ollama_response(
], ],
) )
model_response["choices"][0]["message"] = message model_response["choices"][0]["message"] = message
model_response["choices"][0]["finish_reason"] = "tool_calls"
else: else:
model_response["choices"][0]["message"] = response_json["message"] model_response["choices"][0]["message"] = response_json["message"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
@ -335,6 +337,33 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
custom_llm_provider="ollama_chat", custom_llm_provider="ollama_chat",
logging_obj=logging_obj, logging_obj=logging_obj,
) )
# If format is JSON, this was a function call
# Gather all chunks and return the function call as one delta to simplify parsing
if data.get("format", "") == "json":
first_chunk = next(streamwrapper)
response_content = "".join(
chunk.choices[0].delta.content
for chunk in chain([first_chunk], streamwrapper)
if chunk.choices[0].delta.content
)
function_call = json.loads(response_content)
delta = litellm.utils.Delta(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
"type": "function",
}
],
)
model_response = first_chunk
model_response["choices"][0]["delta"] = delta
model_response["choices"][0]["finish_reason"] = "tool_calls"
yield model_response
else:
for transformed_chunk in streamwrapper: for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
except Exception as e: except Exception as e:
@ -366,6 +395,34 @@ async def ollama_async_streaming(
custom_llm_provider="ollama_chat", custom_llm_provider="ollama_chat",
logging_obj=logging_obj, logging_obj=logging_obj,
) )
# If format is JSON, this was a function call
# Gather all chunks and return the function call as one delta to simplify parsing
if data.get("format", "") == "json":
first_chunk = await anext(streamwrapper)
first_chunk_content = first_chunk.choices[0].delta.content or ""
response_content = first_chunk_content + "".join(
[
chunk.choices[0].delta.content
async for chunk in streamwrapper
if chunk.choices[0].delta.content]
)
function_call = json.loads(response_content)
delta = litellm.utils.Delta(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
"type": "function",
}
],
)
model_response = first_chunk
model_response["choices"][0]["delta"] = delta
model_response["choices"][0]["finish_reason"] = "tool_calls"
yield model_response
else:
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
except Exception as e: except Exception as e:
@ -425,6 +482,7 @@ async def ollama_acompletion(
], ],
) )
model_response["choices"][0]["message"] = message model_response["choices"][0]["message"] = message
model_response["choices"][0]["finish_reason"] = "tool_calls"
else: else:
model_response["choices"][0]["message"] = response_json["message"] model_response["choices"][0]["message"] = response_json["message"]
@@ -1,7 +1,7 @@
import os
import json
from enum import Enum
import requests # type: ignore
import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage
@@ -1,4 +1,13 @@
from typing import (
Optional,
Union,
Any,
BinaryIO,
Literal,
Iterable,
)
from typing_extensions import override
from pydantic import BaseModel
import types, time, json, traceback
import httpx
from .base import BaseLLM
@@ -13,10 +22,10 @@ from litellm.utils import (
TextCompletionResponse,
)
from typing import Callable, Optional
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from openai import OpenAI, AsyncOpenAI
from ..types.llms.openai import *
class OpenAIError(Exception):
@ -246,7 +255,7 @@ class OpenAIChatCompletion(BaseLLM):
def completion( def completion(
self, self,
model_response: ModelResponse, model_response: ModelResponse,
timeout: float, timeout: Union[float, httpx.Timeout],
model: Optional[str] = None, model: Optional[str] = None,
messages: Optional[list] = None, messages: Optional[list] = None,
print_verbose: Optional[Callable] = None, print_verbose: Optional[Callable] = None,
@ -271,9 +280,12 @@ class OpenAIChatCompletion(BaseLLM):
if model is None or messages is None: if model is None or messages is None:
raise OpenAIError(status_code=422, message=f"Missing model or messages") raise OpenAIError(status_code=422, message=f"Missing model or messages")
if not isinstance(timeout, float): if not isinstance(timeout, float) and not isinstance(
timeout, httpx.Timeout
):
raise OpenAIError( raise OpenAIError(
status_code=422, message=f"Timeout needs to be a float" status_code=422,
message=f"Timeout needs to be a float or httpx.Timeout",
) )
if custom_llm_provider != "openai": if custom_llm_provider != "openai":
@ -425,7 +437,7 @@ class OpenAIChatCompletion(BaseLLM):
self, self,
data: dict, data: dict,
model_response: ModelResponse, model_response: ModelResponse,
timeout: float, timeout: Union[float, httpx.Timeout],
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
organization: Optional[str] = None, organization: Optional[str] = None,
@ -480,7 +492,7 @@ class OpenAIChatCompletion(BaseLLM):
def streaming( def streaming(
self, self,
logging_obj, logging_obj,
timeout: float, timeout: Union[float, httpx.Timeout],
data: dict, data: dict,
model: str, model: str,
api_key: Optional[str] = None, api_key: Optional[str] = None,
@ -518,13 +530,14 @@ class OpenAIChatCompletion(BaseLLM):
model=model, model=model,
custom_llm_provider="openai", custom_llm_provider="openai",
logging_obj=logging_obj, logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
) )
return streamwrapper return streamwrapper
async def async_streaming( async def async_streaming(
self, self,
logging_obj, logging_obj,
timeout: float, timeout: Union[float, httpx.Timeout],
data: dict, data: dict,
model: str, model: str,
api_key: Optional[str] = None, api_key: Optional[str] = None,
@ -567,6 +580,7 @@ class OpenAIChatCompletion(BaseLLM):
model=model, model=model,
custom_llm_provider="openai", custom_llm_provider="openai",
logging_obj=logging_obj, logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
) )
return streamwrapper return streamwrapper
except ( except (
@ -1191,6 +1205,7 @@ class OpenAITextCompletion(BaseLLM):
model=model, model=model,
custom_llm_provider="text-completion-openai", custom_llm_provider="text-completion-openai",
logging_obj=logging_obj, logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
) )
for chunk in streamwrapper: for chunk in streamwrapper:
@ -1229,7 +1244,228 @@ class OpenAITextCompletion(BaseLLM):
model=model, model=model,
custom_llm_provider="text-completion-openai", custom_llm_provider="text-completion-openai",
logging_obj=logging_obj, logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
) )
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
class OpenAIAssistantsAPI(BaseLLM):
def __init__(self) -> None:
super().__init__()
def get_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
) -> OpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["base_url"] = v
elif v is not None:
data[k] = v
openai_client = OpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
### ASSISTANTS ###
def get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
) -> SyncCursorPage[Assistant]:
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.assistants.list()
return response
### MESSAGES ###
def add_message(
self,
thread_id: str,
message_data: MessageData,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
) -> OpenAIMessage:
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
thread_message: OpenAIMessage = openai_client.beta.threads.messages.create(
thread_id, **message_data
)
response_obj: Optional[OpenAIMessage] = None
if getattr(thread_message, "status", None) is None:
thread_message.status = "completed"
response_obj = OpenAIMessage(**thread_message.dict())
else:
response_obj = OpenAIMessage(**thread_message.dict())
return response_obj
def get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
) -> SyncCursorPage[OpenAIMessage]:
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.threads.messages.list(thread_id=thread_id)
return response
### THREADS ###
def create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
) -> Thread:
"""
Here's an example:
```
from litellm.llms.openai import OpenAIAssistantsAPI, MessageData
# create thread
message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
openai_api.create_thread(messages=[message])
```
"""
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
data = {}
if messages is not None:
data["messages"] = messages # type: ignore
if metadata is not None:
data["metadata"] = metadata # type: ignore
message_thread = openai_client.beta.threads.create(**data) # type: ignore
return Thread(**message_thread.dict())
def get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
) -> Thread:
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.threads.retrieve(thread_id=thread_id)
return Thread(**response.dict())
def delete_thread(self):
pass
### RUNS ###
def run_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
) -> Run:
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.threads.runs.create_and_poll(
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
)
return response
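For reference, a rough end-to-end sketch of the Assistants flow these helpers wrap, using the openai SDK directly (the assistant id is a placeholder; OPENAI_API_KEY is read from the environment):

from openai import OpenAI

client = OpenAI()
thread = client.beta.threads.create()
client.beta.threads.messages.create(
    thread_id=thread.id, role="user", content="Hey, how's it going?"
)
run = client.beta.threads.runs.create_and_poll(
    thread_id=thread.id,
    assistant_id="asst_placeholder",  # placeholder assistant id
)
messages = client.beta.threads.messages.list(thread_id=thread.id)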
@@ -1,7 +1,7 @@
import os, types
import json
from enum import Enum
import requests # type: ignore
import time
from typing import Callable, Optional
import litellm
litellm/llms/predibase.py (new file, 520 lines)
@ -0,0 +1,520 @@
# What is this?
## Controller file for Predibase Integration - https://predibase.com/
import os, types
import json
from enum import Enum
import requests, copy # type: ignore
import time
from typing import Callable, Optional, List, Literal, Union
from litellm.utils import (
ModelResponse,
Usage,
map_finish_reason,
CustomStreamWrapper,
Message,
Choices,
)
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx # type: ignore
class PredibaseError(Exception):
def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):
self.status_code = status_code
self.message = message
if request is not None:
self.request = request
else:
self.request = httpx.Request(
method="POST",
url="https://docs.predibase.com/user-guide/inference/rest_api",
)
if response is not None:
self.response = response
else:
self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class PredibaseConfig:
"""
Reference: https://docs.predibase.com/user-guide/inference/rest_api
"""
adapter_id: Optional[str] = None
adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
best_of: Optional[int] = None
decoder_input_details: Optional[bool] = None
details: bool = True # enables returning logprobs + best of
max_new_tokens: int = (
256 # openai default - requests hang if max_new_tokens not given
)
repetition_penalty: Optional[float] = None
return_full_text: Optional[bool] = (
False # by default don't return the input as part of the output
)
seed: Optional[int] = None
stop: Optional[List[str]] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[int] = None
truncate: Optional[int] = None
typical_p: Optional[float] = None
watermark: Optional[bool] = None
def __init__(
self,
best_of: Optional[int] = None,
decoder_input_details: Optional[bool] = None,
details: Optional[bool] = None,
max_new_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = None,
seed: Optional[int] = None,
stop: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
class PredibaseChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict:
if api_key is None:
raise ValueError(
"Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
)
headers = {
"content-type": "application/json",
"Authorization": "Bearer {}".format(api_key),
}
if user_headers is not None and isinstance(user_headers, dict):
headers = {**headers, **user_headers}
return headers
def output_parser(self, generated_text: str):
"""
Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
"""
chat_template_tokens = [
"<|assistant|>",
"<|system|>",
"<|user|>",
"<s>",
"</s>",
]
for token in chat_template_tokens:
if generated_text.strip().startswith(token):
generated_text = generated_text.replace(token, "", 1)
if generated_text.endswith(token):
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
return generated_text
def process_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: ModelResponse,
stream: bool,
logging_obj: litellm.utils.Logging,
optional_params: dict,
api_key: str,
data: dict,
messages: list,
print_verbose,
encoding,
) -> ModelResponse:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise PredibaseError(
message=response.text, status_code=response.status_code
)
if "error" in completion_response:
raise PredibaseError(
message=str(completion_response["error"]),
status_code=response.status_code,
)
else:
if (
not isinstance(completion_response, dict)
or "generated_text" not in completion_response
):
raise PredibaseError(
status_code=422,
message=f"response is not in expected format - {completion_response}",
)
if len(completion_response["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = self.output_parser(
completion_response["generated_text"]
)
## GETTING LOGPROBS + FINISH REASON
if (
"details" in completion_response
and "tokens" in completion_response["details"]
):
model_response.choices[0].finish_reason = completion_response[
"details"
]["finish_reason"]
sum_logprob = 0
for token in completion_response["details"]["tokens"]:
if token["logprob"] != None:
sum_logprob += token["logprob"]
model_response["choices"][0][
"message"
]._logprob = (
sum_logprob # [TODO] move this to using the actual logprobs
)
if "best_of" in optional_params and optional_params["best_of"] > 1:
if (
"details" in completion_response
and "best_of_sequences" in completion_response["details"]
):
choices_list = []
for idx, item in enumerate(
completion_response["details"]["best_of_sequences"]
):
sum_logprob = 0
for token in item["tokens"]:
if token["logprob"] != None:
sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0:
message_obj = Message(
content=self.output_parser(item["generated_text"]),
logprobs=sum_logprob,
)
else:
message_obj = Message(content=None)
choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response["choices"].extend(choices_list)
## CALCULATING USAGE
prompt_tokens = 0
try:
prompt_tokens = len(
encoding.encode(model_response["choices"][0]["message"]["content"])
) ##[TODO] use a model-specific tokenizer here
except:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
output_text = model_response["choices"][0]["message"].get("content", "")
if output_text is not None and len(output_text) > 0:
completion_tokens = 0
try:
completion_tokens = len(
encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
) ##[TODO] use a model-specific tokenizer
except:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
else:
completion_tokens = 0
total_tokens = prompt_tokens + completion_tokens
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
model_response.usage = usage # type: ignore
return model_response
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key: str,
logging_obj,
optional_params: dict,
tenant_id: str,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers: dict = {},
) -> Union[ModelResponse, CustomStreamWrapper]:
headers = self._validate_environment(api_key, headers)
completion_url = ""
input_text = ""
base_url = "https://serving.app.predibase.com"
if "https" in model:
completion_url = model
elif api_base:
base_url = api_base
elif "PREDIBASE_API_BASE" in os.environ:
base_url = os.getenv("PREDIBASE_API_BASE", "")
completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"
if optional_params.get("stream", False) == True:
completion_url += "/generate_stream"
else:
completion_url += "/generate"
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
## Load Config
config = litellm.PredibaseConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
stream = optional_params.pop("stream", False)
data = {
"inputs": prompt,
"parameters": optional_params,
}
input_text = prompt
## LOGGING
logging_obj.pre_call(
input=input_text,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
"acompletion": acompletion,
},
)
## COMPLETION CALL
if acompletion is True:
### ASYNC STREAMING
if stream == True:
return self.async_streaming(
model=model,
messages=messages,
data=data,
api_base=completion_url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
) # type: ignore
else:
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data,
api_base=completion_url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
stream=False,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
) # type: ignore
### SYNC STREAMING
if stream == True:
response = requests.post(
completion_url,
headers=headers,
data=json.dumps(data),
stream=stream,
)
_response = CustomStreamWrapper(
response.iter_lines(),
model,
custom_llm_provider="predibase",
logging_obj=logging_obj,
)
return _response
### SYNC COMPLETION
else:
response = requests.post(
url=completion_url,
headers=headers,
data=json.dumps(data),
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=optional_params.get("stream", False),
logging_obj=logging_obj, # type: ignore
optional_params=optional_params,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)
async def async_completion(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
stream,
data: dict,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
) -> ModelResponse:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
response = await self.async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
async def async_streaming(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
data: dict,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
) -> CustomStreamWrapper:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
data["stream"] = True
response = await self.async_handler.post(
url=api_base,
headers=headers,
data=json.dumps(data),
stream=True,
)
if response.status_code != 200:
raise PredibaseError(
status_code=response.status_code, message=response.text
)
completion_stream = response.aiter_lines()
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="predibase",
logging_obj=logging_obj,
)
return streamwrapper
def embedding(self, *args, **kwargs):
pass
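The REST call this class builds boils down to a POST against {base_url}/{tenant_id}/deployments/v2/llms/{model}/generate (or /generate_stream) with a Bearer token. A hedged sketch with placeholder tenant and deployment names:

import json
import os

import requests

base_url = "https://serving.app.predibase.com"
tenant_id = "my-tenant"               # placeholder
deployment = "my-llama-3-deployment"  # placeholder

resp = requests.post(
    f"{base_url}/{tenant_id}/deployments/v2/llms/{deployment}/generate",
    headers={
        "content-type": "application/json",
        "Authorization": f"Bearer {os.environ['PREDIBASE_API_KEY']}",
    },
    data=json.dumps(
        {"inputs": "What is a tokenizer?", "parameters": {"max_new_tokens": 256}}
    ),
)
print(resp.json().get("generated_text"))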
@@ -12,6 +12,16 @@ from typing import (
Sequence,
)
import litellm
from litellm.types.completion import (
ChatCompletionUserMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionMessageParam,
ChatCompletionFunctionMessageParam,
ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam,
)
from litellm.types.llms.anthropic import *
import uuid
def default_pt(messages):
@@ -22,6 +32,41 @@ def prompt_injection_detection_default_pt():
return """Detect if a prompt is safe to run. Return 'UNSAFE' if not."""
def map_system_message_pt(messages: list) -> list:
"""
Convert 'system' message to 'user' message if provider doesn't support 'system' role.
Enabled via `completion(...,supports_system_message=False)`
If next message is a user message or assistant message -> merge system prompt into it
if next message is system -> append a user message instead of the system message
"""
new_messages = []
for i, m in enumerate(messages):
if m["role"] == "system":
if i < len(messages) - 1: # Not the last message
next_m = messages[i + 1]
next_role = next_m["role"]
if (
next_role == "user" or next_role == "assistant"
): # Next message is a user or assistant message
# Merge system prompt into the next message
next_m["content"] = m["content"] + " " + next_m["content"]
elif next_role == "system": # Next message is a system message
# Append a user message instead of the system message
new_message = {"role": "user", "content": m["content"]}
new_messages.append(new_message)
else: # Last message
new_message = {"role": "user", "content": m["content"]}
new_messages.append(new_message)
else: # Not a system message
new_messages.append(m)
return new_messages
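Illustrative input/output for map_system_message_pt as defined above: a leading system message is merged into the following user turn.

messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Hi!"},
]
map_system_message_pt(messages)
# -> [{"role": "user", "content": "You are a terse assistant. Hi!"}]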
# alpaca prompt template - for models like mythomax, etc.
def alpaca_pt(messages):
prompt = custom_prompt(
@@ -805,6 +850,13 @@ def convert_to_anthropic_tool_result(message: dict) -> dict:
"name": "get_current_weather",
"content": "function result goes here",
},
OpenAI message with a function call result looks like:
{
"role": "function",
"name": "get_current_weather",
"content": "function result goes here",
}
""" """
""" """
@ -821,6 +873,7 @@ def convert_to_anthropic_tool_result(message: dict) -> dict:
] ]
} }
""" """
if message["role"] == "tool":
tool_call_id = message.get("tool_call_id") tool_call_id = message.get("tool_call_id")
content = message.get("content") content = message.get("content")
@ -831,8 +884,31 @@ def convert_to_anthropic_tool_result(message: dict) -> dict:
"tool_use_id": tool_call_id, "tool_use_id": tool_call_id,
"content": content, "content": content,
} }
return anthropic_tool_result return anthropic_tool_result
elif message["role"] == "function":
content = message.get("content")
anthropic_tool_result = {
"type": "tool_result",
"tool_use_id": str(uuid.uuid4()),
"content": content,
}
return anthropic_tool_result
return {}
def convert_function_to_anthropic_tool_invoke(function_call):
try:
anthropic_tool_invoke = [
{
"type": "tool_use",
"id": str(uuid.uuid4()),
"name": get_attribute_or_key(function_call, "name"),
"input": json.loads(get_attribute_or_key(function_call, "arguments")),
}
]
return anthropic_tool_invoke
except Exception as e:
raise e
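Illustrative input/output for the helper above: a legacy OpenAI function_call becomes a single Anthropic tool_use block with a freshly generated id.

convert_function_to_anthropic_tool_invoke(
    {"name": "get_current_weather", "arguments": '{"location": "Boston"}'}
)
# -> [{"type": "tool_use", "id": "<uuid4>", "name": "get_current_weather",
#      "input": {"location": "Boston"}}]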
def convert_to_anthropic_tool_invoke(tool_calls: list) -> list:
@@ -895,7 +971,7 @@ def convert_to_anthropic_tool_invoke(tool_calls: list) -> list:
def anthropic_messages_pt(messages: list):
"""
format messages for anthropic
1. Anthropic supports roles like "user" and "assistant" (system prompt sent separately)
2. The first message always needs to be of role "user"
3. Each message must alternate between "user" and "assistant" (this is not addressed as now by litellm)
4. final assistant content cannot end with trailing whitespace (anthropic raises an error otherwise)
@@ -903,12 +979,14 @@ def anthropic_messages_pt(messages: list):
6. Ensure we only accept role, content. (message.name is not supported)
"""
# add role=tool support to allow function call result/error submission
user_message_types = {"user", "tool", "function"}
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, merge them.
new_messages: list = []
msg_i = 0
tool_use_param = False
while msg_i < len(messages):
user_content = []
init_msg_i = msg_i
## MERGE CONSECUTIVE USER CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
if isinstance(messages[msg_i]["content"], list):
@ -924,7 +1002,10 @@ def anthropic_messages_pt(messages: list):
) )
elif m.get("type", "") == "text": elif m.get("type", "") == "text":
user_content.append({"type": "text", "text": m["text"]}) user_content.append({"type": "text", "text": m["text"]})
elif messages[msg_i]["role"] == "tool": elif (
messages[msg_i]["role"] == "tool"
or messages[msg_i]["role"] == "function"
):
# OpenAI's tool message content will always be a string # OpenAI's tool message content will always be a string
user_content.append(convert_to_anthropic_tool_result(messages[msg_i])) user_content.append(convert_to_anthropic_tool_result(messages[msg_i]))
else: else:
@ -953,11 +1034,24 @@ def anthropic_messages_pt(messages: list):
convert_to_anthropic_tool_invoke(messages[msg_i]["tool_calls"]) convert_to_anthropic_tool_invoke(messages[msg_i]["tool_calls"])
) )
if messages[msg_i].get("function_call"):
assistant_content.extend(
convert_function_to_anthropic_tool_invoke(
messages[msg_i]["function_call"]
)
)
msg_i += 1 msg_i += 1
if assistant_content: if assistant_content:
new_messages.append({"role": "assistant", "content": assistant_content}) new_messages.append({"role": "assistant", "content": assistant_content})
if msg_i == init_msg_i: # prevent infinite loops
raise Exception(
"Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
messages[msg_i]
)
)
if not new_messages or new_messages[0]["role"] != "user": if not new_messages or new_messages[0]["role"] != "user":
if litellm.modify_params: if litellm.modify_params:
new_messages.insert( new_messages.insert(
@ -969,6 +1063,9 @@ def anthropic_messages_pt(messages: list):
) )
if new_messages[-1]["role"] == "assistant": if new_messages[-1]["role"] == "assistant":
if isinstance(new_messages[-1]["content"], str):
new_messages[-1]["content"] = new_messages[-1]["content"].rstrip()
elif isinstance(new_messages[-1]["content"], list):
for content in new_messages[-1]["content"]: for content in new_messages[-1]["content"]:
if isinstance(content, dict) and content["type"] == "text": if isinstance(content, dict) and content["type"] == "text":
content["text"] = content[ content["text"] = content[

View file

@ -1,11 +1,11 @@
import os, types import os, types
import json import json
import requests import requests # type: ignore
import time import time
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
import litellm import litellm
import httpx import httpx # type: ignore
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt

View file

@ -1,14 +1,14 @@
import os, types, traceback import os, types, traceback
from enum import Enum from enum import Enum
import json import json
import requests import requests # type: ignore
import time import time
from typing import Callable, Optional, Any from typing import Callable, Optional, Any
import litellm import litellm
from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage
import sys import sys
from copy import deepcopy from copy import deepcopy
import httpx import httpx # type: ignore
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
@ -295,7 +295,7 @@ def completion(
EndpointName={model}, EndpointName={model},
InferenceComponentName={model_id}, InferenceComponentName={model_id},
ContentType="application/json", ContentType="application/json",
Body={data}, Body={data}, # type: ignore
CustomAttributes="accept_eula=true", CustomAttributes="accept_eula=true",
) )
""" # type: ignore """ # type: ignore
@ -321,7 +321,7 @@ def completion(
response = client.invoke_endpoint( response = client.invoke_endpoint(
EndpointName={model}, EndpointName={model},
ContentType="application/json", ContentType="application/json",
Body={data}, Body={data}, # type: ignore
CustomAttributes="accept_eula=true", CustomAttributes="accept_eula=true",
) )
""" # type: ignore """ # type: ignore
@ -688,7 +688,7 @@ def embedding(
response = client.invoke_endpoint( response = client.invoke_endpoint(
EndpointName={model}, EndpointName={model},
ContentType="application/json", ContentType="application/json",
Body={data}, Body={data}, # type: ignore
CustomAttributes="accept_eula=true", CustomAttributes="accept_eula=true",
)""" # type: ignore )""" # type: ignore
logging_obj.pre_call( logging_obj.pre_call(

View file

@ -6,11 +6,11 @@ Reference: https://docs.together.ai/docs/openai-api-compatibility
import os, types import os, types
import json import json
from enum import Enum from enum import Enum
import requests import requests # type: ignore
import time import time
from typing import Callable, Optional from typing import Callable, Optional
import litellm import litellm
import httpx import httpx # type: ignore
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt

litellm/llms/triton.py Normal file
View file

@ -0,0 +1,119 @@
import os, types
import json
from enum import Enum
import requests, copy # type: ignore
import time
from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx # type: ignore
class TritonError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST",
url="https://api.anthropic.com/v1/messages", # using anthropic api base since httpx requires a url
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class TritonChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
async def aembedding(
self,
data: dict,
model_response: litellm.utils.EmbeddingResponse,
api_base: str,
logging_obj=None,
api_key: Optional[str] = None,
):
async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
response = await async_handler.post(url=api_base, data=json.dumps(data))
if response.status_code != 200:
raise TritonError(status_code=response.status_code, message=response.text)
_text_response = response.text
logging_obj.post_call(original_response=_text_response)
_json_response = response.json()
_outputs = _json_response["outputs"]
_output_data = _outputs[0]["data"]
_embedding_output = {
"object": "embedding",
"index": 0,
"embedding": _output_data,
}
model_response.model = _json_response.get("model_name", "None")
model_response.data = [_embedding_output]
return model_response
def embedding(
self,
model: str,
input: list,
timeout: float,
api_base: str,
model_response: litellm.utils.EmbeddingResponse,
api_key: Optional[str] = None,
logging_obj=None,
optional_params=None,
client=None,
aembedding=None,
):
data_for_triton = {
"inputs": [
{
"name": "input_text",
"shape": [1],
"datatype": "BYTES",
"data": input,
}
]
}
## LOGGING
curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"
logging_obj.pre_call(
input="",
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": curl_string,
},
)
if aembedding == True:
response = self.aembedding(
data=data_for_triton,
model_response=model_response,
logging_obj=logging_obj,
api_base=api_base,
api_key=api_key,
)
return response
else:
raise Exception(
"Only async embedding supported for triton, please use litellm.aembedding() for now"
)
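A sketch of exercising this new Triton path through litellm.aembedding(). The "triton/..." model prefix and the endpoint URL are assumptions for illustration; the code above only requires an explicit api_base and currently supports the async path only.

import asyncio
import litellm

async def main():
    response = await litellm.aembedding(
        model="triton/my-embedding-model",  # hypothetical model name
        api_base="http://localhost:8000/v2/models/my-embedding-model/infer",  # hypothetical Triton endpoint
        input=["hello from litellm"],
    )
    # the handler above returns a single embedding dict per input
    print(response.data[0]["embedding"][:5])

asyncio.run(main())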

View file

@ -1,12 +1,12 @@
import os, types import os, types
import json import json
from enum import Enum from enum import Enum
import requests import requests # type: ignore
import time import time
from typing import Callable, Optional, Union, List from typing import Callable, Optional, Union, List
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
import litellm, uuid import litellm, uuid
import httpx, inspect import httpx, inspect # type: ignore
class VertexAIError(Exception): class VertexAIError(Exception):
@ -419,6 +419,7 @@ def completion(
from google.protobuf.struct_pb2 import Value # type: ignore from google.protobuf.struct_pb2 import Value # type: ignore
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore
import google.auth # type: ignore import google.auth # type: ignore
import proto # type: ignore
## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744 ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
print_verbose( print_verbose(
@ -605,9 +606,21 @@ def completion(
): ):
function_call = response.candidates[0].content.parts[0].function_call function_call = response.candidates[0].content.parts[0].function_call
args_dict = {} args_dict = {}
for k, v in function_call.args.items():
args_dict[k] = v # Check if it's a RepeatedComposite instance
for key, val in function_call.args.items():
if isinstance(
val, proto.marshal.collections.repeated.RepeatedComposite
):
# If so, convert to list
args_dict[key] = [v for v in val]
else:
args_dict[key] = val
try:
args_str = json.dumps(args_dict) args_str = json.dumps(args_dict)
except Exception as e:
raise VertexAIError(status_code=422, message=str(e))
message = litellm.Message( message = litellm.Message(
content=None, content=None,
tool_calls=[ tool_calls=[
@ -810,6 +823,8 @@ def completion(
setattr(model_response, "usage", usage) setattr(model_response, "usage", usage)
return model_response return model_response
except Exception as e: except Exception as e:
if isinstance(e, VertexAIError):
raise e
raise VertexAIError(status_code=500, message=str(e)) raise VertexAIError(status_code=500, message=str(e))
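An illustrative sketch of the arg normalization above: proto-plus RepeatedComposite values are not JSON serializable, so they are expanded into plain lists before json.dumps(). A duck-typed check stands in here for the isinstance() check on proto.marshal.collections.repeated.RepeatedComposite, and a plain ValueError stands in for VertexAIError.

import json

def normalize_function_args(args) -> str:
    args_dict = {}
    for key, val in args.items():
        # anything iterable but not a plain str/list/dict gets copied into a list
        if hasattr(val, "__iter__") and not isinstance(val, (str, list, dict)):
            args_dict[key] = [v for v in val]
        else:
            args_dict[key] = val
    try:
        return json.dumps(args_dict)
    except Exception as e:
        # mirrors the VertexAIError(status_code=422) raised above
        raise ValueError(f"function_call args are not JSON serializable: {e}")

print(normalize_function_args({"cities": ("Boston", "Paris"), "units": "metric"}))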

View file

@ -3,7 +3,7 @@
import os, types import os, types
import json import json
from enum import Enum from enum import Enum
import requests, copy import requests, copy # type: ignore
import time, uuid import time, uuid
from typing import Callable, Optional, List from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
@ -17,7 +17,7 @@ from .prompt_templates.factory import (
extract_between_tags, extract_between_tags,
parse_xml_params, parse_xml_params,
) )
import httpx import httpx # type: ignore
class VertexAIError(Exception): class VertexAIError(Exception):

View file

@ -1,8 +1,8 @@
import os import os
import json import json
from enum import Enum from enum import Enum
import requests import requests # type: ignore
import time, httpx import time, httpx # type: ignore
from typing import Callable, Any from typing import Callable, Any
from litellm.utils import ModelResponse, Usage from litellm.utils import ModelResponse, Usage
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt

View file

@ -3,8 +3,8 @@ import json, types, time # noqa: E401
from contextlib import contextmanager from contextlib import contextmanager
from typing import Callable, Dict, Optional, Any, Union, List from typing import Callable, Dict, Optional, Any, Union, List
import httpx import httpx # type: ignore
import requests import requests # type: ignore
import litellm import litellm
from litellm.utils import ModelResponse, get_secret, Usage from litellm.utils import ModelResponse, get_secret, Usage

View file

@ -14,7 +14,6 @@ import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy from copy import deepcopy
import httpx import httpx
import litellm import litellm
from ._logging import verbose_logger from ._logging import verbose_logger
from litellm import ( # type: ignore from litellm import ( # type: ignore
client, client,
@ -34,9 +33,12 @@ from litellm.utils import (
async_mock_completion_streaming_obj, async_mock_completion_streaming_obj,
convert_to_model_response_object, convert_to_model_response_object,
token_counter, token_counter,
create_pretrained_tokenizer,
create_tokenizer,
Usage, Usage,
get_optional_params_embeddings, get_optional_params_embeddings,
get_optional_params_image_gen, get_optional_params_image_gen,
supports_httpx_timeout,
) )
from .llms import ( from .llms import (
anthropic_text, anthropic_text,
@ -44,6 +46,7 @@ from .llms import (
ai21, ai21,
sagemaker, sagemaker,
bedrock, bedrock,
triton,
huggingface_restapi, huggingface_restapi,
replicate, replicate,
aleph_alpha, aleph_alpha,
@ -71,10 +74,13 @@ from .llms.azure_text import AzureTextCompletion
from .llms.anthropic import AnthropicChatCompletion from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import ( from .llms.prompt_templates.factory import (
prompt_factory, prompt_factory,
custom_prompt, custom_prompt,
function_call_prompt, function_call_prompt,
map_system_message_pt,
) )
import tiktoken import tiktoken
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@ -106,6 +112,8 @@ anthropic_text_completions = AnthropicTextCompletion()
azure_chat_completions = AzureChatCompletion() azure_chat_completions = AzureChatCompletion()
azure_text_completions = AzureTextCompletion() azure_text_completions = AzureTextCompletion()
huggingface = Huggingface() huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion()
####### COMPLETION ENDPOINTS ################ ####### COMPLETION ENDPOINTS ################
@ -184,6 +192,7 @@ async def acompletion(
top_p: Optional[float] = None, top_p: Optional[float] = None,
n: Optional[int] = None, n: Optional[int] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
stream_options: Optional[dict] = None,
stop=None, stop=None,
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None, presence_penalty: Optional[float] = None,
@ -203,6 +212,7 @@ async def acompletion(
api_version: Optional[str] = None, api_version: Optional[str] = None,
api_key: Optional[str] = None, api_key: Optional[str] = None,
model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. model_list: Optional[list] = None, # pass in a list of api_base,keys, etc.
extra_headers: Optional[dict] = None,
# Optional liteLLM function params # Optional liteLLM function params
**kwargs, **kwargs,
): ):
@ -220,6 +230,7 @@ async def acompletion(
top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0). top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0).
n (int, optional): The number of completions to generate (default is 1). n (int, optional): The number of completions to generate (default is 1).
stream (bool, optional): If True, return a streaming response (default is False). stream (bool, optional): If True, return a streaming response (default is False).
stream_options (dict, optional): A dictionary containing options for the streaming response. Only use this if stream is True.
stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens. stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens.
max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity). max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far. presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.
@ -257,6 +268,7 @@ async def acompletion(
"top_p": top_p, "top_p": top_p,
"n": n, "n": n,
"stream": stream, "stream": stream,
"stream_options": stream_options,
"stop": stop, "stop": stop,
"max_tokens": max_tokens, "max_tokens": max_tokens,
"presence_penalty": presence_penalty, "presence_penalty": presence_penalty,
@ -301,6 +313,7 @@ async def acompletion(
or custom_llm_provider == "deepinfra" or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity" or custom_llm_provider == "perplexity"
or custom_llm_provider == "groq" or custom_llm_provider == "groq"
or custom_llm_provider == "deepseek"
or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "text-completion-openai"
or custom_llm_provider == "huggingface" or custom_llm_provider == "huggingface"
or custom_llm_provider == "ollama" or custom_llm_provider == "ollama"
@ -309,6 +322,7 @@ async def acompletion(
or custom_llm_provider == "gemini" or custom_llm_provider == "gemini"
or custom_llm_provider == "sagemaker" or custom_llm_provider == "sagemaker"
or custom_llm_provider == "anthropic" or custom_llm_provider == "anthropic"
or custom_llm_provider == "predibase"
or custom_llm_provider in litellm.openai_compatible_providers or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context) init_response = await loop.run_in_executor(None, func_with_context)
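An example of awaiting acompletion for one of the providers now routed through this native-async path (deepseek here). The DEEPSEEK_API_KEY variable follows the usual <PROVIDER>_API_KEY convention and is an assumption in this sketch, as is the exact model name.

import asyncio
import os
import litellm

os.environ["DEEPSEEK_API_KEY"] = "sk-..."  # placeholder key

async def main():
    response = await litellm.acompletion(
        model="deepseek/deepseek-chat",  # assumed provider-prefixed model name
        messages=[{"role": "user", "content": "Say hi in one word."}],
        max_tokens=16,
    )
    print(response.choices[0].message.content)

asyncio.run(main())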
@ -448,11 +462,12 @@ def completion(
model: str, model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [], messages: List = [],
timeout: Optional[Union[float, int]] = None, timeout: Optional[Union[float, str, httpx.Timeout]] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
n: Optional[int] = None, n: Optional[int] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
stream_options: Optional[dict] = None,
stop=None, stop=None,
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None, presence_penalty: Optional[float] = None,
@ -492,6 +507,7 @@ def completion(
top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0). top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0).
n (int, optional): The number of completions to generate (default is 1). n (int, optional): The number of completions to generate (default is 1).
stream (bool, optional): If True, return a streaming response (default is False). stream (bool, optional): If True, return a streaming response (default is False).
stream_options (dict, optional): A dictionary containing options for the streaming response. Only set this when you set stream: true.
stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens. stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens.
max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity). max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far. presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.
@ -551,6 +567,7 @@ def completion(
eos_token = kwargs.get("eos_token", None) eos_token = kwargs.get("eos_token", None)
preset_cache_key = kwargs.get("preset_cache_key", None) preset_cache_key = kwargs.get("preset_cache_key", None)
hf_model_name = kwargs.get("hf_model_name", None) hf_model_name = kwargs.get("hf_model_name", None)
supports_system_message = kwargs.get("supports_system_message", None)
### TEXT COMPLETION CALLS ### ### TEXT COMPLETION CALLS ###
text_completion = kwargs.get("text_completion", False) text_completion = kwargs.get("text_completion", False)
atext_completion = kwargs.get("atext_completion", False) atext_completion = kwargs.get("atext_completion", False)
@ -568,6 +585,7 @@ def completion(
"top_p", "top_p",
"n", "n",
"stream", "stream",
"stream_options",
"stop", "stop",
"max_tokens", "max_tokens",
"presence_penalty", "presence_penalty",
@ -616,6 +634,7 @@ def completion(
"model_list", "model_list",
"num_retries", "num_retries",
"context_window_fallback_dict", "context_window_fallback_dict",
"retry_policy",
"roles", "roles",
"final_prompt_value", "final_prompt_value",
"bos_token", "bos_token",
@ -641,16 +660,30 @@ def completion(
"no-log", "no-log",
"base_model", "base_model",
"stream_timeout", "stream_timeout",
"supports_system_message",
"region_name",
"allowed_model_region",
] ]
default_params = openai_params + litellm_params default_params = openai_params + litellm_params
non_default_params = { non_default_params = {
k: v for k, v in kwargs.items() if k not in default_params k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider } # model-specific params - pass them straight to the model/provider
if timeout is None:
timeout = ( ### TIMEOUT LOGIC ###
kwargs.get("request_timeout", None) or 600 timeout = timeout or kwargs.get("request_timeout", 600) or 600
) # set timeout for 10 minutes by default # set timeout for 10 minutes by default
timeout = float(timeout)
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
try: try:
if base_url is not None: if base_url is not None:
api_base = base_url api_base = base_url
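A sketch of the timeout normalization above: an httpx.Timeout object is collapsed to its read timeout for providers that cannot accept one, and plain numbers or strings are coerced to float, with 600 seconds as the default. supports_httpx_timeout() is stubbed here with an assumed provider list for illustration.

import httpx

def supports_httpx_timeout(provider: str) -> bool:
    # stand-in for litellm's helper; the provider list here is an assumption
    return provider in {"openai", "azure"}

def normalize_timeout(timeout, custom_llm_provider: str):
    timeout = timeout or 600  # default: 10 minutes
    if isinstance(timeout, httpx.Timeout) and not supports_httpx_timeout(custom_llm_provider):
        return timeout.read or 600  # fall back to the read timeout
    if not isinstance(timeout, httpx.Timeout):
        return float(timeout)
    return timeout

print(normalize_timeout(httpx.Timeout(connect=5.0, read=30.0, write=30.0, pool=5.0), "bedrock"))  # 30.0
print(normalize_timeout("120", "openai"))  # 120.0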
@ -745,6 +778,13 @@ def completion(
custom_prompt_dict[model]["bos_token"] = bos_token custom_prompt_dict[model]["bos_token"] = bos_token
if eos_token: if eos_token:
custom_prompt_dict[model]["eos_token"] = eos_token custom_prompt_dict[model]["eos_token"] = eos_token
if (
supports_system_message is not None
and isinstance(supports_system_message, bool)
and supports_system_message == False
):
messages = map_system_message_pt(messages=messages)
model_api_key = get_api_key( model_api_key = get_api_key(
llm_provider=custom_llm_provider, dynamic_api_key=api_key llm_provider=custom_llm_provider, dynamic_api_key=api_key
) # get the api key from the environment if required for the model ) # get the api key from the environment if required for the model
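For context, an illustrative sketch of what remapping system messages for a model without system-message support might look like. This is not litellm's map_system_message_pt implementation, just one plausible approach in which system content is folded into the first user turn.

def fold_system_messages(messages: list) -> list:
    # collect all system text, then drop the system turns
    system_text = "\n".join(
        m["content"] for m in messages if m["role"] == "system" and m.get("content")
    )
    remaining = [m for m in messages if m["role"] != "system"]
    if system_text:
        remaining.insert(0, {"role": "user", "content": system_text})
    return remaining

print(fold_system_messages([
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Explain HTTP keep-alive."},
]))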
@ -759,6 +799,7 @@ def completion(
top_p=top_p, top_p=top_p,
n=n, n=n,
stream=stream, stream=stream,
stream_options=stream_options,
stop=stop, stop=stop,
max_tokens=max_tokens, max_tokens=max_tokens,
presence_penalty=presence_penalty, presence_penalty=presence_penalty,
@ -871,7 +912,7 @@ def completion(
logger_fn=logger_fn, logger_fn=logger_fn,
logging_obj=logging, logging_obj=logging,
acompletion=acompletion, acompletion=acompletion,
timeout=timeout, timeout=timeout, # type: ignore
client=client, # pass AsyncAzureOpenAI, AzureOpenAI client client=client, # pass AsyncAzureOpenAI, AzureOpenAI client
) )
@ -958,6 +999,7 @@ def completion(
or custom_llm_provider == "deepinfra" or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity" or custom_llm_provider == "perplexity"
or custom_llm_provider == "groq" or custom_llm_provider == "groq"
or custom_llm_provider == "deepseek"
or custom_llm_provider == "anyscale" or custom_llm_provider == "anyscale"
or custom_llm_provider == "mistral" or custom_llm_provider == "mistral"
or custom_llm_provider == "openai" or custom_llm_provider == "openai"
@ -1012,7 +1054,7 @@ def completion(
optional_params=optional_params, optional_params=optional_params,
litellm_params=litellm_params, litellm_params=litellm_params,
logger_fn=logger_fn, logger_fn=logger_fn,
timeout=timeout, timeout=timeout, # type: ignore
custom_prompt_dict=custom_prompt_dict, custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client client=client, # pass AsyncOpenAI, OpenAI client
organization=organization, organization=organization,
@ -1097,7 +1139,7 @@ def completion(
optional_params=optional_params, optional_params=optional_params,
litellm_params=litellm_params, litellm_params=litellm_params,
logger_fn=logger_fn, logger_fn=logger_fn,
timeout=timeout, timeout=timeout, # type: ignore
) )
if ( if (
@ -1471,7 +1513,7 @@ def completion(
acompletion=acompletion, acompletion=acompletion,
logging_obj=logging, logging_obj=logging,
custom_prompt_dict=custom_prompt_dict, custom_prompt_dict=custom_prompt_dict,
timeout=timeout, timeout=timeout, # type: ignore
) )
if ( if (
"stream" in optional_params "stream" in optional_params
@ -1564,7 +1606,7 @@ def completion(
logger_fn=logger_fn, logger_fn=logger_fn,
logging_obj=logging, logging_obj=logging,
acompletion=acompletion, acompletion=acompletion,
timeout=timeout, timeout=timeout, # type: ignore
) )
## LOGGING ## LOGGING
logging.post_call( logging.post_call(
@ -1749,6 +1791,52 @@ def completion(
) )
return response return response
response = model_response response = model_response
elif custom_llm_provider == "predibase":
tenant_id = (
optional_params.pop("tenant_id", None)
or optional_params.pop("predibase_tenant_id", None)
or litellm.predibase_tenant_id
or get_secret("PREDIBASE_TENANT_ID")
)
api_base = (
optional_params.pop("api_base", None)
or optional_params.pop("base_url", None)
or litellm.api_base
or get_secret("PREDIBASE_API_BASE")
)
api_key = (
api_key
or litellm.api_key
or litellm.predibase_key
or get_secret("PREDIBASE_API_KEY")
)
_model_response = predibase_chat_completions.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
acompletion=acompletion,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
api_key=api_key,
tenant_id=tenant_id,
)
if (
"stream" in optional_params
and optional_params["stream"] == True
and acompletion == False
):
return _model_response
response = _model_response
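A sketch of routing a completion through the new Predibase branch. The model name is hypothetical; PREDIBASE_API_KEY and PREDIBASE_TENANT_ID are the environment variables read above.

import os
import litellm

os.environ["PREDIBASE_API_KEY"] = "pb_..."        # placeholder
os.environ["PREDIBASE_TENANT_ID"] = "my-tenant"   # placeholder

response = litellm.completion(
    model="predibase/llama-3-8b-instruct",  # hypothetical deployment name
    messages=[{"role": "user", "content": "Hello from litellm"}],
)
print(response.choices[0].message.content)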
elif custom_llm_provider == "ai21": elif custom_llm_provider == "ai21":
custom_llm_provider = "ai21" custom_llm_provider = "ai21"
ai21_key = ( ai21_key = (
@ -1844,6 +1932,7 @@ def completion(
logger_fn=logger_fn, logger_fn=logger_fn,
encoding=encoding, encoding=encoding,
logging_obj=logging, logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout, timeout=timeout,
) )
@ -1891,7 +1980,7 @@ def completion(
logger_fn=logger_fn, logger_fn=logger_fn,
encoding=encoding, encoding=encoding,
logging_obj=logging, logging_obj=logging,
timeout=timeout, timeout=timeout, # type: ignore
) )
if ( if (
"stream" in optional_params "stream" in optional_params
@ -2272,7 +2361,7 @@ def batch_completion(
n: Optional[int] = None, n: Optional[int] = None,
stream: Optional[bool] = None, stream: Optional[bool] = None,
stop=None, stop=None,
max_tokens: Optional[float] = None, max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None, presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None,
logit_bias: Optional[dict] = None, logit_bias: Optional[dict] = None,
@ -2535,11 +2624,13 @@ async def aembedding(*args, **kwargs):
or custom_llm_provider == "voyage" or custom_llm_provider == "voyage"
or custom_llm_provider == "mistral" or custom_llm_provider == "mistral"
or custom_llm_provider == "custom_openai" or custom_llm_provider == "custom_openai"
or custom_llm_provider == "triton"
or custom_llm_provider == "anyscale" or custom_llm_provider == "anyscale"
or custom_llm_provider == "openrouter" or custom_llm_provider == "openrouter"
or custom_llm_provider == "deepinfra" or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity" or custom_llm_provider == "perplexity"
or custom_llm_provider == "groq" or custom_llm_provider == "groq"
or custom_llm_provider == "deepseek"
or custom_llm_provider == "fireworks_ai" or custom_llm_provider == "fireworks_ai"
or custom_llm_provider == "ollama" or custom_llm_provider == "ollama"
or custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai"
@ -2665,6 +2756,7 @@ def embedding(
"model_list", "model_list",
"num_retries", "num_retries",
"context_window_fallback_dict", "context_window_fallback_dict",
"retry_policy",
"roles", "roles",
"final_prompt_value", "final_prompt_value",
"bos_token", "bos_token",
@ -2688,6 +2780,8 @@ def embedding(
"ttl", "ttl",
"cache", "cache",
"no-log", "no-log",
"region_name",
"allowed_model_region",
] ]
default_params = openai_params + litellm_params default_params = openai_params + litellm_params
non_default_params = { non_default_params = {
@ -2864,23 +2958,43 @@ def embedding(
optional_params=optional_params, optional_params=optional_params,
model_response=EmbeddingResponse(), model_response=EmbeddingResponse(),
) )
elif custom_llm_provider == "triton":
if api_base is None:
raise ValueError(
"api_base is required for triton. Please pass `api_base`"
)
response = triton_chat_completions.embedding(
model=model,
input=input,
api_base=api_base,
api_key=api_key,
logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(),
optional_params=optional_params,
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "vertex_ai": elif custom_llm_provider == "vertex_ai":
vertex_ai_project = ( vertex_ai_project = (
optional_params.pop("vertex_project", None) optional_params.pop("vertex_project", None)
or optional_params.pop("vertex_ai_project", None) or optional_params.pop("vertex_ai_project", None)
or litellm.vertex_project or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT") or get_secret("VERTEXAI_PROJECT")
or get_secret("VERTEX_PROJECT")
) )
vertex_ai_location = ( vertex_ai_location = (
optional_params.pop("vertex_location", None) optional_params.pop("vertex_location", None)
or optional_params.pop("vertex_ai_location", None) or optional_params.pop("vertex_ai_location", None)
or litellm.vertex_location or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION") or get_secret("VERTEXAI_LOCATION")
or get_secret("VERTEX_LOCATION")
) )
vertex_credentials = ( vertex_credentials = (
optional_params.pop("vertex_credentials", None) optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None) or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS") or get_secret("VERTEXAI_CREDENTIALS")
or get_secret("VERTEX_CREDENTIALS")
) )
response = vertex_ai.embedding( response = vertex_ai.embedding(
@ -2921,8 +3035,10 @@ def embedding(
model=model, # type: ignore model=model, # type: ignore
llm_provider="ollama", # type: ignore llm_provider="ollama", # type: ignore
) )
if aembedding: ollama_embeddings_fn = (
response = ollama.ollama_aembeddings( ollama.ollama_aembeddings if aembedding else ollama.ollama_embeddings
)
response = ollama_embeddings_fn(
api_base=api_base, api_base=api_base,
model=model, model=model,
prompts=input, prompts=input,
@ -3059,11 +3175,13 @@ async def atext_completion(*args, **kwargs):
or custom_llm_provider == "deepinfra" or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity" or custom_llm_provider == "perplexity"
or custom_llm_provider == "groq" or custom_llm_provider == "groq"
or custom_llm_provider == "deepseek"
or custom_llm_provider == "fireworks_ai" or custom_llm_provider == "fireworks_ai"
or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "text-completion-openai"
or custom_llm_provider == "huggingface" or custom_llm_provider == "huggingface"
or custom_llm_provider == "ollama" or custom_llm_provider == "ollama"
or custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai"
or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure and openai, soon all. ): # currently implemented aiohttp calls for just azure and openai, soon all.
# Await normally # Await normally
response = await loop.run_in_executor(None, func_with_context) response = await loop.run_in_executor(None, func_with_context)
@ -3094,6 +3212,8 @@ async def atext_completion(*args, **kwargs):
## TRANSLATE CHAT TO TEXT FORMAT ## ## TRANSLATE CHAT TO TEXT FORMAT ##
if isinstance(response, TextCompletionResponse): if isinstance(response, TextCompletionResponse):
return response return response
elif asyncio.iscoroutine(response):
response = await response
text_completion_response = TextCompletionResponse() text_completion_response = TextCompletionResponse()
text_completion_response["id"] = response.get("id", None) text_completion_response["id"] = response.get("id", None)
@ -3153,6 +3273,7 @@ def text_completion(
Union[str, List[str]] Union[str, List[str]]
] = None, # Optional: Sequences where the API will stop generating further tokens. ] = None, # Optional: Sequences where the API will stop generating further tokens.
stream: Optional[bool] = None, # Optional: Whether to stream back partial progress. stream: Optional[bool] = None, # Optional: Whether to stream back partial progress.
stream_options: Optional[dict] = None,
suffix: Optional[ suffix: Optional[
str str
] = None, # Optional: The suffix that comes after a completion of inserted text. ] = None, # Optional: The suffix that comes after a completion of inserted text.
@ -3230,6 +3351,8 @@ def text_completion(
optional_params["stop"] = stop optional_params["stop"] = stop
if stream is not None: if stream is not None:
optional_params["stream"] = stream optional_params["stream"] = stream
if stream_options is not None:
optional_params["stream_options"] = stream_options
if suffix is not None: if suffix is not None:
optional_params["suffix"] = suffix optional_params["suffix"] = suffix
if temperature is not None: if temperature is not None:
@ -3340,7 +3463,9 @@ def text_completion(
if kwargs.get("acompletion", False) == True: if kwargs.get("acompletion", False) == True:
return response return response
if stream == True or kwargs.get("stream", False) == True: if stream == True or kwargs.get("stream", False) == True:
response = TextCompletionStreamWrapper(completion_stream=response, model=model) response = TextCompletionStreamWrapper(
completion_stream=response, model=model, stream_options=stream_options
)
return response return response
transformed_logprobs = None transformed_logprobs = None
# only supported for TGI models # only supported for TGI models
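A sketch of the new stream_options parameter on a streaming text_completion call. {"include_usage": True} is the OpenAI-style option and is assumed here; when supported, it asks the provider to append a final usage chunk to the stream.

import litellm

stream = litellm.text_completion(
    model="gpt-3.5-turbo-instruct",
    prompt="Write a haiku about HTTP.",
    stream=True,
    stream_options={"include_usage": True},  # assumed option key, per the OpenAI API
)
for chunk in stream:
    print(chunk)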
@ -3534,6 +3659,7 @@ def image_generation(
"model_list", "model_list",
"num_retries", "num_retries",
"context_window_fallback_dict", "context_window_fallback_dict",
"retry_policy",
"roles", "roles",
"final_prompt_value", "final_prompt_value",
"bos_token", "bos_token",
@ -3554,6 +3680,8 @@ def image_generation(
"caching_groups", "caching_groups",
"ttl", "ttl",
"cache", "cache",
"region_name",
"allowed_model_region",
] ]
default_params = openai_params + litellm_params default_params = openai_params + litellm_params
non_default_params = { non_default_params = {

View file

@ -338,6 +338,18 @@
"output_cost_per_second": 0.0001, "output_cost_per_second": 0.0001,
"litellm_provider": "azure" "litellm_provider": "azure"
}, },
"azure/gpt-4-turbo-2024-04-09": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00001,
"output_cost_per_token": 0.00003,
"litellm_provider": "azure",
"mode": "chat",
"supports_function_calling": true,
"supports_parallel_function_calling": true,
"supports_vision": true
},
"azure/gpt-4-0125-preview": { "azure/gpt-4-0125-preview": {
"max_tokens": 4096, "max_tokens": 4096,
"max_input_tokens": 128000, "max_input_tokens": 128000,
@ -727,6 +739,24 @@
"litellm_provider": "mistral", "litellm_provider": "mistral",
"mode": "embedding" "mode": "embedding"
}, },
"deepseek-chat": {
"max_tokens": 4096,
"max_input_tokens": 32000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000014,
"output_cost_per_token": 0.00000028,
"litellm_provider": "deepseek",
"mode": "chat"
},
"deepseek-coder": {
"max_tokens": 4096,
"max_input_tokens": 16000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000014,
"output_cost_per_token": 0.00000028,
"litellm_provider": "deepseek",
"mode": "chat"
},
"groq/llama2-70b-4096": { "groq/llama2-70b-4096": {
"max_tokens": 4096, "max_tokens": 4096,
"max_input_tokens": 4096, "max_input_tokens": 4096,
@ -813,6 +843,7 @@
"litellm_provider": "anthropic", "litellm_provider": "anthropic",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true,
"tool_use_system_prompt_tokens": 264 "tool_use_system_prompt_tokens": 264
}, },
"claude-3-opus-20240229": { "claude-3-opus-20240229": {
@ -824,6 +855,7 @@
"litellm_provider": "anthropic", "litellm_provider": "anthropic",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true,
"tool_use_system_prompt_tokens": 395 "tool_use_system_prompt_tokens": 395
}, },
"claude-3-sonnet-20240229": { "claude-3-sonnet-20240229": {
@ -835,6 +867,7 @@
"litellm_provider": "anthropic", "litellm_provider": "anthropic",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true,
"tool_use_system_prompt_tokens": 159 "tool_use_system_prompt_tokens": 159
}, },
"text-bison": { "text-bison": {
@ -1045,8 +1078,8 @@
"max_tokens": 8192, "max_tokens": 8192,
"max_input_tokens": 1000000, "max_input_tokens": 1000000,
"max_output_tokens": 8192, "max_output_tokens": 8192,
"input_cost_per_token": 0, "input_cost_per_token": 0.000000625,
"output_cost_per_token": 0, "output_cost_per_token": 0.000001875,
"litellm_provider": "vertex_ai-language-models", "litellm_provider": "vertex_ai-language-models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
@ -1057,8 +1090,8 @@
"max_tokens": 8192, "max_tokens": 8192,
"max_input_tokens": 1000000, "max_input_tokens": 1000000,
"max_output_tokens": 8192, "max_output_tokens": 8192,
"input_cost_per_token": 0, "input_cost_per_token": 0.000000625,
"output_cost_per_token": 0, "output_cost_per_token": 0.000001875,
"litellm_provider": "vertex_ai-language-models", "litellm_provider": "vertex_ai-language-models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
@ -1069,8 +1102,8 @@
"max_tokens": 8192, "max_tokens": 8192,
"max_input_tokens": 1000000, "max_input_tokens": 1000000,
"max_output_tokens": 8192, "max_output_tokens": 8192,
"input_cost_per_token": 0, "input_cost_per_token": 0.000000625,
"output_cost_per_token": 0, "output_cost_per_token": 0.000001875,
"litellm_provider": "vertex_ai-language-models", "litellm_provider": "vertex_ai-language-models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
@ -1142,7 +1175,8 @@
"output_cost_per_token": 0.000015, "output_cost_per_token": 0.000015,
"litellm_provider": "vertex_ai-anthropic_models", "litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true "supports_function_calling": true,
"supports_vision": true
}, },
"vertex_ai/claude-3-haiku@20240307": { "vertex_ai/claude-3-haiku@20240307": {
"max_tokens": 4096, "max_tokens": 4096,
@ -1152,7 +1186,8 @@
"output_cost_per_token": 0.00000125, "output_cost_per_token": 0.00000125,
"litellm_provider": "vertex_ai-anthropic_models", "litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true "supports_function_calling": true,
"supports_vision": true
}, },
"vertex_ai/claude-3-opus@20240229": { "vertex_ai/claude-3-opus@20240229": {
"max_tokens": 4096, "max_tokens": 4096,
@ -1162,7 +1197,8 @@
"output_cost_per_token": 0.0000075, "output_cost_per_token": 0.0000075,
"litellm_provider": "vertex_ai-anthropic_models", "litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true "supports_function_calling": true,
"supports_vision": true
}, },
"textembedding-gecko": { "textembedding-gecko": {
"max_tokens": 3072, "max_tokens": 3072,
@ -1581,6 +1617,7 @@
"litellm_provider": "openrouter", "litellm_provider": "openrouter",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true,
"tool_use_system_prompt_tokens": 395 "tool_use_system_prompt_tokens": 395
}, },
"openrouter/google/palm-2-chat-bison": { "openrouter/google/palm-2-chat-bison": {
@ -1813,6 +1850,15 @@
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "embedding" "mode": "embedding"
}, },
"amazon.titan-embed-text-v2:0": {
"max_tokens": 8192,
"max_input_tokens": 8192,
"output_vector_size": 1024,
"input_cost_per_token": 0.0000002,
"output_cost_per_token": 0.0,
"litellm_provider": "bedrock",
"mode": "embedding"
},
"mistral.mistral-7b-instruct-v0:2": { "mistral.mistral-7b-instruct-v0:2": {
"max_tokens": 8191, "max_tokens": 8191,
"max_input_tokens": 32000, "max_input_tokens": 32000,
@ -1929,7 +1975,8 @@
"output_cost_per_token": 0.000015, "output_cost_per_token": 0.000015,
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "chat", "mode": "chat",
"supports_function_calling": true "supports_function_calling": true,
"supports_vision": true
}, },
"anthropic.claude-3-haiku-20240307-v1:0": { "anthropic.claude-3-haiku-20240307-v1:0": {
"max_tokens": 4096, "max_tokens": 4096,
@ -1939,7 +1986,8 @@
"output_cost_per_token": 0.00000125, "output_cost_per_token": 0.00000125,
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "chat", "mode": "chat",
"supports_function_calling": true "supports_function_calling": true,
"supports_vision": true
}, },
"anthropic.claude-3-opus-20240229-v1:0": { "anthropic.claude-3-opus-20240229-v1:0": {
"max_tokens": 4096, "max_tokens": 4096,
@ -1949,7 +1997,8 @@
"output_cost_per_token": 0.000075, "output_cost_per_token": 0.000075,
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "chat", "mode": "chat",
"supports_function_calling": true "supports_function_calling": true,
"supports_vision": true
}, },
"anthropic.claude-v1": { "anthropic.claude-v1": {
"max_tokens": 8191, "max_tokens": 8191,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1 @@
(self.webpackChunk_N_E=self.webpackChunk_N_E||[]).push([[185],{87421:function(n,e,t){Promise.resolve().then(t.t.bind(t,99646,23)),Promise.resolve().then(t.t.bind(t,63385,23))},63385:function(){},99646:function(n){n.exports={style:{fontFamily:"'__Inter_c23dc8', '__Inter_Fallback_c23dc8'",fontStyle:"normal"},className:"__className_c23dc8"}}},function(n){n.O(0,[971,69,744],function(){return n(n.s=87421)}),_N_E=n.O()}]);

View file

@ -1 +0,0 @@
(self.webpackChunk_N_E=self.webpackChunk_N_E||[]).push([[185],{93553:function(n,e,t){Promise.resolve().then(t.t.bind(t,63385,23)),Promise.resolve().then(t.t.bind(t,99646,23))},63385:function(){},99646:function(n){n.exports={style:{fontFamily:"'__Inter_12bbc4', '__Inter_Fallback_12bbc4'",fontStyle:"normal"},className:"__className_12bbc4"}}},function(n){n.O(0,[971,69,744],function(){return n(n.s=93553)}),_N_E=n.O()}]);

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

Some files were not shown because too many files have changed in this diff