forked from phoenix/litellm-mirror

commit bbe1300c5b: Merge branch 'main' into feat/add-azure-content-filter

200 changed files with 19218 additions and 2966 deletions
@@ -1,4 +1,4 @@
-version: 2.1
+version: 4.3.4
 jobs:
   local_testing:
     docker:

@@ -188,7 +188,7 @@ jobs:
 command: |
 docker run -d \
 -p 4000:4000 \
--e DATABASE_URL=$PROXY_DOCKER_DB_URL \
+-e DATABASE_URL=$PROXY_DATABASE_URL \
 -e AZURE_API_KEY=$AZURE_API_KEY \
 -e REDIS_HOST=$REDIS_HOST \
 -e REDIS_PASSWORD=$REDIS_PASSWORD \

@@ -198,6 +198,7 @@ jobs:
 -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
 -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
 -e AWS_REGION_NAME=$AWS_REGION_NAME \
+-e AUTO_INFER_REGION=True \
 -e OPENAI_API_KEY=$OPENAI_API_KEY \
 -e LANGFUSE_PROJECT1_PUBLIC=$LANGFUSE_PROJECT1_PUBLIC \
 -e LANGFUSE_PROJECT2_PUBLIC=$LANGFUSE_PROJECT2_PUBLIC \

@@ -208,9 +209,7 @@ jobs:
 my-app:latest \
 --config /app/config.yaml \
 --port 4000 \
---num_workers 8 \
 --detailed_debug \
---run_gunicorn \
 - run:
     name: Install curl and dockerize
     command: |

@@ -225,7 +224,7 @@ jobs:
     background: true
 - run:
     name: Wait for app to be ready
-    command: dockerize -wait http://localhost:4000 -timeout 1m
+    command: dockerize -wait http://localhost:4000 -timeout 5m
 - run:
     name: Run tests
     command: |
.devcontainer/devcontainer.json (new file, 51 lines)
@@ -0,0 +1,51 @@
{
    "name": "Python 3.11",
    // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
    "image": "mcr.microsoft.com/devcontainers/python:3.11-bookworm",
    // https://github.com/devcontainers/images/tree/main/src/python
    // https://mcr.microsoft.com/en-us/product/devcontainers/python/tags

    // "build": {
    //     "dockerfile": "Dockerfile",
    //     "context": ".."
    // },

    // Features to add to the dev container. More info: https://containers.dev/features.
    // "features": {},

    // Configure tool-specific properties.
    "customizations": {
        // Configure properties specific to VS Code.
        "vscode": {
            "settings": {},
            "extensions": [
                "ms-python.python",
                "ms-python.vscode-pylance",
                "GitHub.copilot",
                "GitHub.copilot-chat"
            ]
        }
    },

    // Use 'forwardPorts' to make a list of ports inside the container available locally.
    "forwardPorts": [4000],

    "containerEnv": {
        "LITELLM_LOG": "DEBUG"
    },

    // Use 'portsAttributes' to set default properties for specific forwarded ports.
    // More info: https://containers.dev/implementors/json_reference/#port-attributes
    "portsAttributes": {
        "4000": {
            "label": "LiteLLM Server",
            "onAutoForward": "notify"
        }
    },

    // More info: https://aka.ms/dev-containers-non-root.
    // "remoteUser": "litellm",

    // Use 'postCreateCommand' to run commands after the container is created.
    "postCreateCommand": "pipx install poetry && poetry install -E extra_proxy -E proxy"
}
.github/workflows/interpret_load_test.py (vendored, 19 changes)
@@ -64,6 +64,11 @@ if __name__ == "__main__":
 ) # Replace with your repository's username and name
 latest_release = repo.get_latest_release()
 print("got latest release: ", latest_release)
+print(latest_release.title)
+print(latest_release.tag_name)
+
+release_version = latest_release.title
+
 print("latest release body: ", latest_release.body)
 print("markdown table: ", markdown_table)

@@ -74,8 +79,22 @@ if __name__ == "__main__":
 start_index = latest_release.body.find("Load Test LiteLLM Proxy Results")
 existing_release_body = latest_release.body[:start_index]

+docker_run_command = f"""
+\n\n
+## Docker Run LiteLLM Proxy
+
+```
+docker run \\
+-e STORE_MODEL_IN_DB=True \\
+-p 4000:4000 \\
+ghcr.io/berriai/litellm:main-{release_version}
+```
+"""
+print("docker run command: ", docker_run_command)
+
 new_release_body = (
     existing_release_body
+    + docker_run_command
     + "\n\n"
     + "### Don't want to maintain your internal proxy? get in touch 🎉"
     + "\nHosted Proxy Alpha: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat"
.gitignore (vendored, 4 changes)
@@ -1,5 +1,6 @@
 .venv
 .env
+litellm/proxy/myenv/*
 litellm_uuid.txt
 __pycache__/
 *.pyc

@@ -52,3 +53,6 @@ litellm/proxy/_new_secret_config.yaml
 litellm/proxy/_new_secret_config.yaml
 litellm/proxy/_super_secret_config.yaml
 litellm/proxy/_super_secret_config.yaml
+litellm/proxy/myenv/bin/activate
+litellm/proxy/myenv/bin/Activate.ps1
+myenv/*
@@ -226,6 +226,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
 | [deepinfra](https://docs.litellm.ai/docs/providers/deepinfra) | ✅ | ✅ | ✅ | ✅ |
 | [perplexity-ai](https://docs.litellm.ai/docs/providers/perplexity) | ✅ | ✅ | ✅ | ✅ |
 | [Groq AI](https://docs.litellm.ai/docs/providers/groq) | ✅ | ✅ | ✅ | ✅ |
+| [Deepseek](https://docs.litellm.ai/docs/providers/deepseek) | ✅ | ✅ | ✅ | ✅ |
 | [anyscale](https://docs.litellm.ai/docs/providers/anyscale) | ✅ | ✅ | ✅ | ✅ |
 | [IBM - watsonx.ai](https://docs.litellm.ai/docs/providers/watsonx) | ✅ | ✅ | ✅ | ✅ | ✅
 | [voyage ai](https://docs.litellm.ai/docs/providers/voyage) | | | | | ✅ |
deploy/azure_resource_manager/azure_marketplace.zip (new binary file; binary file not shown)

New file (15 lines)
@@ -0,0 +1,15 @@
{
    "$schema": "https://schema.management.azure.com/schemas/0.1.2-preview/CreateUIDefinition.MultiVm.json#",
    "handler": "Microsoft.Azure.CreateUIDef",
    "version": "0.1.2-preview",
    "parameters": {
        "config": {
            "isWizard": false,
            "basics": { }
        },
        "basics": [ ],
        "steps": [ ],
        "outputs": { },
        "resourceTypes": [ ]
    }
}
New file (63 lines)
@@ -0,0 +1,63 @@
{
    "$schema": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#",
    "contentVersion": "1.0.0.0",
    "parameters": {
        "imageName": {
            "type": "string",
            "defaultValue": "ghcr.io/berriai/litellm:main-latest"
        },
        "containerName": {
            "type": "string",
            "defaultValue": "litellm-container"
        },
        "dnsLabelName": {
            "type": "string",
            "defaultValue": "litellm"
        },
        "portNumber": {
            "type": "int",
            "defaultValue": 4000
        }
    },
    "resources": [
        {
            "type": "Microsoft.ContainerInstance/containerGroups",
            "apiVersion": "2021-03-01",
            "name": "[parameters('containerName')]",
            "location": "[resourceGroup().location]",
            "properties": {
                "containers": [
                    {
                        "name": "[parameters('containerName')]",
                        "properties": {
                            "image": "[parameters('imageName')]",
                            "resources": {
                                "requests": {
                                    "cpu": 1,
                                    "memoryInGB": 2
                                }
                            },
                            "ports": [
                                {
                                    "port": "[parameters('portNumber')]"
                                }
                            ]
                        }
                    }
                ],
                "osType": "Linux",
                "restartPolicy": "Always",
                "ipAddress": {
                    "type": "Public",
                    "ports": [
                        {
                            "protocol": "tcp",
                            "port": "[parameters('portNumber')]"
                        }
                    ],
                    "dnsNameLabel": "[parameters('dnsLabelName')]"
                }
            }
        }
    ]
}
deploy/azure_resource_manager/main.bicep (new file, 42 lines)
@@ -0,0 +1,42 @@
param imageName string = 'ghcr.io/berriai/litellm:main-latest'
param containerName string = 'litellm-container'
param dnsLabelName string = 'litellm'
param portNumber int = 4000

resource containerGroupName 'Microsoft.ContainerInstance/containerGroups@2021-03-01' = {
  name: containerName
  location: resourceGroup().location
  properties: {
    containers: [
      {
        name: containerName
        properties: {
          image: imageName
          resources: {
            requests: {
              cpu: 1
              memoryInGB: 2
            }
          }
          ports: [
            {
              port: portNumber
            }
          ]
        }
      }
    ]
    osType: 'Linux'
    restartPolicy: 'Always'
    ipAddress: {
      type: 'Public'
      ports: [
        {
          protocol: 'tcp'
          port: portNumber
        }
      ]
      dnsNameLabel: dnsLabelName
    }
  }
}
@@ -24,7 +24,7 @@ version: 0.2.0
 # incremented each time you make changes to the application. Versions are not expected to
 # follow Semantic Versioning. They should reflect the version the application is using.
 # It is recommended to use it with quotes.
-appVersion: v1.24.5
+appVersion: v1.35.38

 dependencies:
   - name: "postgresql"
@@ -83,8 +83,9 @@ def completion(
     top_p: Optional[float] = None,
     n: Optional[int] = None,
     stream: Optional[bool] = None,
+    stream_options: Optional[dict] = None,
     stop=None,
-    max_tokens: Optional[float] = None,
+    max_tokens: Optional[int] = None,
     presence_penalty: Optional[float] = None,
     frequency_penalty: Optional[float] = None,
     logit_bias: Optional[dict] = None,

@@ -139,6 +140,10 @@

 - `stream`: *boolean or null (optional)* - If set to true, it sends partial message deltas. Tokens will be sent as they become available, with the stream terminated by a [DONE] message.

+- `stream_options` *dict or null (optional)* - Options for streaming response. Only set this when you set `stream: true`
+
+- `include_usage` *boolean (optional)* - If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
+
 - `stop`: *string/ array/ null (optional)* - Up to 4 sequences where the API will stop generating further tokens.

 - `max_tokens`: *integer (optional)* - The maximum number of tokens to generate in the chat completion.
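Since the hunk above only documents the new `stream_options` field, here is a minimal, hedged sketch (not part of the diff) of how it could be passed to `litellm.completion`; the model name is an arbitrary placeholder.

```python
from litellm import completion

# hedged sketch: request streaming plus a final usage chunk via the new stream_options field
response = completion(
    model="gpt-3.5-turbo",  # placeholder model
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
    stream_options={"include_usage": True},  # per the doc text above
)

for chunk in response:
    print(chunk)  # the final chunk(s) should carry a populated `usage` field
```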
@@ -1,7 +1,7 @@
 # Completion Token Usage & Cost
 By default LiteLLM returns token usage in all completion requests ([See here](https://litellm.readthedocs.io/en/latest/output/))

-However, we also expose 5 helper functions + **[NEW]** an API to calculate token usage across providers:
+However, we also expose some helper functions + **[NEW]** an API to calculate token usage across providers:

 - `encode`: This encodes the text passed in, using the model-specific tokenizer. [**Jump to code**](#1-encode)

@@ -9,17 +9,19 @@ However, we also expose 5 helper functions + **[NEW]** an API to calculate token

 - `token_counter`: This returns the number of tokens for a given input - it uses the tokenizer based on the model, and defaults to tiktoken if no model-specific tokenizer is available. [**Jump to code**](#3-token_counter)

-- `cost_per_token`: This returns the cost (in USD) for prompt (input) and completion (output) tokens. Uses the live list from `api.litellm.ai`. [**Jump to code**](#4-cost_per_token)
+- `create_pretrained_tokenizer` and `create_tokenizer`: LiteLLM provides default tokenizer support for OpenAI, Cohere, Anthropic, Llama2, and Llama3 models. If you are using a different model, you can create a custom tokenizer and pass it as `custom_tokenizer` to the `encode`, `decode`, and `token_counter` methods. [**Jump to code**](#4-create_pretrained_tokenizer-and-create_tokenizer)

-- `completion_cost`: This returns the overall cost (in USD) for a given LLM API Call. It combines `token_counter` and `cost_per_token` to return the cost for that query (counting both cost of input and output). [**Jump to code**](#5-completion_cost)
+- `cost_per_token`: This returns the cost (in USD) for prompt (input) and completion (output) tokens. Uses the live list from `api.litellm.ai`. [**Jump to code**](#5-cost_per_token)

-- `get_max_tokens`: This returns the maximum number of tokens allowed for the given model. [**Jump to code**](#6-get_max_tokens)
+- `completion_cost`: This returns the overall cost (in USD) for a given LLM API Call. It combines `token_counter` and `cost_per_token` to return the cost for that query (counting both cost of input and output). [**Jump to code**](#6-completion_cost)

-- `model_cost`: This returns a dictionary for all models, with their max_tokens, input_cost_per_token and output_cost_per_token. It uses the `api.litellm.ai` call shown below. [**Jump to code**](#7-model_cost)
+- `get_max_tokens`: This returns the maximum number of tokens allowed for the given model. [**Jump to code**](#7-get_max_tokens)

-- `register_model`: This registers new / overrides existing models (and their pricing details) in the model cost dictionary. [**Jump to code**](#8-register_model)
+- `model_cost`: This returns a dictionary for all models, with their max_tokens, input_cost_per_token and output_cost_per_token. It uses the `api.litellm.ai` call shown below. [**Jump to code**](#8-model_cost)

-- `api.litellm.ai`: Live token + price count across [all supported models](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). [**Jump to code**](#9-apilitellmai)
+- `register_model`: This registers new / overrides existing models (and their pricing details) in the model cost dictionary. [**Jump to code**](#9-register_model)
+
+- `api.litellm.ai`: Live token + price count across [all supported models](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). [**Jump to code**](#10-apilitellmai)

 📣 This is a community maintained list. Contributions are welcome! ❤️

@@ -60,7 +62,24 @@ messages = [{"user": "role", "content": "Hey, how's it going"}]
 print(token_counter(model="gpt-3.5-turbo", messages=messages))
 ```

-### 4. `cost_per_token`
+### 4. `create_pretrained_tokenizer` and `create_tokenizer`
+
+```python
+from litellm import create_pretrained_tokenizer, create_tokenizer
+
+# get tokenizer from huggingface repo
+custom_tokenizer_1 = create_pretrained_tokenizer("Xenova/llama-3-tokenizer")
+
+# use tokenizer from json file
+with open("tokenizer.json") as f:
+    json_data = json.load(f)
+
+json_str = json.dumps(json_data)
+
+custom_tokenizer_2 = create_tokenizer(json_str)
+```
+
+### 5. `cost_per_token`

 ```python
 from litellm import cost_per_token
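The bullet above says a custom tokenizer can be passed as `custom_tokenizer` to `encode`, `decode`, and `token_counter`; as a hedged sketch (not part of the diff), combining it with `token_counter` might look like the snippet below. The tokenizer repo id and the keyword usage follow the doc text above and are assumptions, not verified API.

```python
from litellm import create_pretrained_tokenizer, token_counter

# assumption: any Hugging Face tokenizer repo id works here; this one mirrors the doc's example
custom_tokenizer = create_pretrained_tokenizer("Xenova/llama-3-tokenizer")

messages = [{"role": "user", "content": "Hey, how's it going"}]

# count tokens with the custom tokenizer instead of the model-based default
print(token_counter(custom_tokenizer=custom_tokenizer, messages=messages))
```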
@@ -72,7 +91,7 @@ prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_toke
 print(prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar)
 ```

-### 5. `completion_cost`
+### 6. `completion_cost`

 * Input: Accepts a `litellm.completion()` response **OR** prompt + completion strings
 * Output: Returns a `float` of cost for the `completion` call

@@ -99,7 +118,7 @@ cost = completion_cost(model="bedrock/anthropic.claude-v2", prompt="Hey!", compl
 formatted_string = f"${float(cost):.10f}"
 print(formatted_string)
 ```
-### 6. `get_max_tokens`
+### 7. `get_max_tokens`

 Input: Accepts a model name - e.g., gpt-3.5-turbo (to get a complete list, call litellm.model_list).
 Output: Returns the maximum number of tokens allowed for the given model

@@ -112,7 +131,7 @@ model = "gpt-3.5-turbo"
 print(get_max_tokens(model)) # Output: 4097
 ```

-### 7. `model_cost`
+### 8. `model_cost`

 * Output: Returns a dict object containing the max_tokens, input_cost_per_token, output_cost_per_token for all models on [community-maintained list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)

@@ -122,7 +141,7 @@ from litellm import model_cost
 print(model_cost) # {'gpt-3.5-turbo': {'max_tokens': 4000, 'input_cost_per_token': 1.5e-06, 'output_cost_per_token': 2e-06}, ...}
 ```

-### 8. `register_model`
+### 9. `register_model`

 * Input: Provide EITHER a model cost dictionary or a url to a hosted json blob
 * Output: Returns updated model_cost dictionary + updates litellm.model_cost with model details.

@@ -157,5 +176,3 @@ export LITELLM_LOCAL_MODEL_COST_MAP="True"
 ```

 Note: this means you will need to upgrade to get updated pricing, and newer models.
-
-
@@ -320,8 +320,6 @@ from litellm import embedding
 litellm.vertex_project = "hardy-device-38811" # Your Project ID
 litellm.vertex_location = "us-central1" # proj location

-
-os.environ['VOYAGE_API_KEY'] = ""
 response = embedding(
     model="vertex_ai/textembedding-gecko",
     input=["good morning from litellm"],
@@ -17,6 +17,14 @@ This covers:
 - ✅ [**JWT-Auth**](../docs/proxy/token_auth.md)


+## [COMING SOON] AWS Marketplace Support
+
+Deploy managed LiteLLM Proxy within your VPC.
+
+Includes all enterprise features.
+
+[**Get early access**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
+
 ## Frequently Asked Questions

 ### What topics does Professional support cover and what SLAs do you offer?
@@ -13,7 +13,7 @@ LiteLLM maps exceptions across all providers to their OpenAI counterparts.
 | >=500 | InternalServerError |
 | N/A | ContextWindowExceededError|
 | 400 | ContentPolicyViolationError|
-| N/A | APIConnectionError |
+| 500 | APIConnectionError |


 Base case we return APIConnectionError
@@ -74,6 +74,28 @@ except Exception as e:

 ```

+## Usage - Should you retry exception?
+
+```
+import litellm
+import openai
+
+try:
+    response = litellm.completion(
+        model="gpt-4",
+        messages=[
+            {
+                "role": "user",
+                "content": "hello, write a 20 pageg essay"
+            }
+        ],
+        timeout=0.01, # this will raise a timeout exception
+    )
+except openai.APITimeoutError as e:
+    should_retry = litellm._should_retry(e.status_code)
+    print(f"should_retry: {should_retry}")
+```
+
 ## Details

 To see how it's implemented - [check out the code](https://github.com/BerriAI/litellm/blob/a42c197e5a6de56ea576c73715e6c7c6b19fa249/litellm/utils.py#L1217)
@@ -86,21 +108,34 @@ To see how it's implemented - [check out the code](https://github.com/BerriAI/li

 Base case - we return the original exception.

-| | ContextWindowExceededError | AuthenticationError | InvalidRequestError | RateLimitError | ServiceUnavailableError |
-|---------------|----------------------------|---------------------|---------------------|---------------|-------------------------|
-| Anthropic | ✅ | ✅ | ✅ | ✅ | |
-| OpenAI | ✅ | ✅ |✅ |✅ |✅|
-| Azure OpenAI | ✅ | ✅ |✅ |✅ |✅|
-| Replicate | ✅ | ✅ | ✅ | ✅ | ✅ |
-| Cohere | ✅ | ✅ | ✅ | ✅ | ✅ |
-| Huggingface | ✅ | ✅ | ✅ | ✅ | |
-| Openrouter | ✅ | ✅ | ✅ | ✅ | |
-| AI21 | ✅ | ✅ | ✅ | ✅ | |
-| VertexAI | | |✅ | | |
-| Bedrock | | |✅ | | |
-| Sagemaker | | |✅ | | |
-| TogetherAI | ✅ | ✅ | ✅ | ✅ | |
-| AlephAlpha | ✅ | ✅ | ✅ | ✅ | ✅ |
+| custom_llm_provider | Timeout | ContextWindowExceededError | BadRequestError | NotFoundError | ContentPolicyViolationError | AuthenticationError | APIError | RateLimitError | ServiceUnavailableError | PermissionDeniedError | UnprocessableEntityError |
+|----------------------------|---------|----------------------------|------------------|---------------|-----------------------------|---------------------|----------|----------------|-------------------------|-----------------------|-------------------------|
+| openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
+| text-completion-openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
+| custom_openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
+| openai_compatible_providers| ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
+| anthropic | ✓ | ✓ | ✓ | ✓ | | ✓ | | | ✓ | ✓ | |
+| replicate | ✓ | ✓ | ✓ | ✓ | | ✓ | | ✓ | ✓ | | |
+| bedrock | ✓ | ✓ | ✓ | ✓ | | ✓ | | ✓ | ✓ | ✓ | |
+| sagemaker | | ✓ | ✓ | | | | | | | | |
+| vertex_ai | ✓ | | ✓ | | | | ✓ | | | | ✓ |
+| palm | ✓ | ✓ | | | | | ✓ | | | | |
+| gemini | ✓ | ✓ | | | | | ✓ | | | | |
+| cloudflare | | | ✓ | | | ✓ | | | | | |
+| cohere | | ✓ | ✓ | | | ✓ | | | ✓ | | |
+| cohere_chat | | ✓ | ✓ | | | ✓ | | | ✓ | | |
+| huggingface | ✓ | ✓ | ✓ | | | ✓ | | ✓ | ✓ | | |
+| ai21 | ✓ | ✓ | ✓ | ✓ | | ✓ | | ✓ | | | |
+| nlp_cloud | ✓ | ✓ | ✓ | | | ✓ | ✓ | ✓ | ✓ | | |
+| together_ai | ✓ | ✓ | ✓ | | | ✓ | | | | | |
+| aleph_alpha | | | ✓ | | | ✓ | | | | | |
+| ollama | ✓ | | ✓ | | | | | | ✓ | | |
+| ollama_chat | ✓ | | ✓ | | | | | | ✓ | | |
+| vllm | | | | | | ✓ | ✓ | | | | |
+| azure | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | | ✓ | | |
+
+- "✓" indicates that the specified `custom_llm_provider` can raise the corresponding exception.
+- Empty cells indicate the lack of association or that the provider does not raise that particular exception type as indicated by the function.

 > For a deeper understanding of these exceptions, you can check out [this](https://github.com/BerriAI/litellm/blob/d7e58d13bf9ba9edbab2ab2f096f3de7547f35fa/litellm/utils.py#L1544) implementation for additional insights.
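As a hedged illustration of the mapping table above (not part of the diff), catching the OpenAI-style exception classes that litellm re-raises might look like the sketch below; the model and the oversized prompt are arbitrary choices.

```python
import litellm
from litellm import completion

# deliberately oversized prompt to provoke a context-window failure
messages = [{"role": "user", "content": "Hey, how's it going? " * 100000}]

try:
    response = completion(model="claude-2", messages=messages)
except litellm.ContextWindowExceededError as e:
    # per the table, anthropic can raise ContextWindowExceededError
    print("context window exceeded:", e)
except litellm.APIConnectionError as e:
    # per the doc, unmapped errors fall back to APIConnectionError
    print("connection-level error:", e)
```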
@@ -46,4 +46,13 @@ Pricing is based on usage. We can figure out a price that works for your team, o

 <Image img={require('../img/litellm_hosted_ui_router.png')} />

 #### [**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
+
+## Feature List
+
+- Easy way to add/remove models
+- 100% uptime even when models are added/removed
+- custom callback webhooks
+- your domain name with HTTPS
+- Ability to create/delete User API keys
+- Reasonable set monthly cost
@@ -14,14 +14,14 @@ import TabItem from '@theme/TabItem';

 ```python
 import os
-from langchain.chat_models import ChatLiteLLM
+from langchain_community.chat_models import ChatLiteLLM
-from langchain.prompts.chat import (
+from langchain_core.prompts import (
     ChatPromptTemplate,
     SystemMessagePromptTemplate,
     AIMessagePromptTemplate,
     HumanMessagePromptTemplate,
 )
-from langchain.schema import AIMessage, HumanMessage, SystemMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

 os.environ['OPENAI_API_KEY'] = ""
 chat = ChatLiteLLM(model="gpt-3.5-turbo")

@@ -30,7 +30,7 @@ messages = [
         content="what model are you"
     )
 ]
-chat(messages)
+chat.invoke(messages)
 ```

 </TabItem>

@@ -39,14 +39,14 @@ chat(messages)

 ```python
 import os
-from langchain.chat_models import ChatLiteLLM
+from langchain_community.chat_models import ChatLiteLLM
-from langchain.prompts.chat import (
+from langchain_core.prompts import (
     ChatPromptTemplate,
     SystemMessagePromptTemplate,
     AIMessagePromptTemplate,
     HumanMessagePromptTemplate,
 )
-from langchain.schema import AIMessage, HumanMessage, SystemMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

 os.environ['ANTHROPIC_API_KEY'] = ""
 chat = ChatLiteLLM(model="claude-2", temperature=0.3)

@@ -55,7 +55,7 @@ messages = [
         content="what model are you"
     )
 ]
-chat(messages)
+chat.invoke(messages)
 ```

 </TabItem>

@@ -64,14 +64,14 @@ chat(messages)

 ```python
 import os
-from langchain.chat_models import ChatLiteLLM
+from langchain_community.chat_models import ChatLiteLLM
-from langchain.prompts.chat import (
+from langchain_core.prompts.chat import (
     ChatPromptTemplate,
     SystemMessagePromptTemplate,
     AIMessagePromptTemplate,
     HumanMessagePromptTemplate,
 )
-from langchain.schema import AIMessage, HumanMessage, SystemMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

 os.environ['REPLICATE_API_TOKEN'] = ""
 chat = ChatLiteLLM(model="replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1")

@@ -80,7 +80,7 @@ messages = [
         content="what model are you?"
     )
 ]
-chat(messages)
+chat.invoke(messages)
 ```

 </TabItem>

@@ -89,14 +89,14 @@ chat(messages)

 ```python
 import os
-from langchain.chat_models import ChatLiteLLM
+from langchain_community.chat_models import ChatLiteLLM
-from langchain.prompts.chat import (
+from langchain_core.prompts import (
     ChatPromptTemplate,
     SystemMessagePromptTemplate,
     AIMessagePromptTemplate,
     HumanMessagePromptTemplate,
 )
-from langchain.schema import AIMessage, HumanMessage, SystemMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

 os.environ['COHERE_API_KEY'] = ""
 chat = ChatLiteLLM(model="command-nightly")

@@ -105,32 +105,9 @@ messages = [
         content="what model are you?"
     )
 ]
-chat(messages)
+chat.invoke(messages)
 ```

-</TabItem>
-<TabItem value="palm" label="PaLM - Google">
-
-```python
-import os
-from langchain.chat_models import ChatLiteLLM
-from langchain.prompts.chat import (
-    ChatPromptTemplate,
-    SystemMessagePromptTemplate,
-    AIMessagePromptTemplate,
-    HumanMessagePromptTemplate,
-)
-from langchain.schema import AIMessage, HumanMessage, SystemMessage
-
-os.environ['PALM_API_KEY'] = ""
-chat = ChatLiteLLM(model="palm/chat-bison")
-messages = [
-    HumanMessage(
-        content="what model are you?"
-    )
-]
-chat(messages)
-```
 </TabItem>
 </Tabs>

@@ -94,9 +94,10 @@ print(response)

 ```

-### Set Custom Trace ID, Trace User ID and Tags
+### Set Custom Trace ID, Trace User ID, Trace Metadata, Trace Version, Trace Release and Tags

-Pass `trace_id`, `trace_user_id` in `metadata`
+Pass `trace_id`, `trace_user_id`, `trace_metadata`, `trace_version`, `trace_release`, `tags` in `metadata`

 ```python
 import litellm

@@ -121,12 +122,20 @@ response = completion(
     metadata={
         "generation_name": "ishaan-test-generation", # set langfuse Generation Name
         "generation_id": "gen-id22", # set langfuse Generation ID
+        "version": "test-generation-version" # set langfuse Generation Version
         "trace_user_id": "user-id2", # set langfuse Trace User ID
         "session_id": "session-1", # set langfuse Session ID
-        "tags": ["tag1", "tag2"] # set langfuse Tags
+        "tags": ["tag1", "tag2"], # set langfuse Tags
         "trace_id": "trace-id22", # set langfuse Trace ID
+        "trace_metadata": {"key": "value"}, # set langfuse Trace Metadata
+        "trace_version": "test-trace-version", # set langfuse Trace Version (if not set, defaults to Generation Version)
+        "trace_release": "test-trace-release", # set langfuse Trace Release
         ### OR ###
         "existing_trace_id": "trace-id22", # if generation is continuation of past trace. This prevents default behaviour of setting a trace name
+        ### OR enforce that certain fields are trace overwritten in the trace during the continuation ###
+        "existing_trace_id": "trace-id22",
+        "trace_metadata": {"key": "updated_trace_value"}, # The new value to use for the langfuse Trace Metadata
+        "update_trace_keys": ["input", "output", "trace_metadata"], # Updates the trace input & output to be this generations input & output also updates the Trace Metadata to match the passed in value
     },
 )

@@ -134,6 +143,38 @@ print(response)

 ```

+### Trace & Generation Parameters
+
+#### Trace Specific Parameters
+
+* `trace_id` - Identifier for the trace, must use `existing_trace_id` instead or in conjunction with `trace_id` if this is an existing trace, auto-generated by default
+* `trace_name` - Name of the trace, auto-generated by default
+* `session_id` - Session identifier for the trace, defaults to `None`
+* `trace_version` - Version for the trace, defaults to value for `version`
+* `trace_release` - Release for the trace, defaults to `None`
+* `trace_metadata` - Metadata for the trace, defaults to `None`
+* `trace_user_id` - User identifier for the trace, defaults to completion argument `user`
+* `tags` - Tags for the trace, defeaults to `None`
+
+##### Updatable Parameters on Continuation
+
+The following parameters can be updated on a continuation of a trace by passing in the following values into the `update_trace_keys` in the metadata of the completion.
+
+* `input` - Will set the traces input to be the input of this latest generation
+* `output` - Will set the traces output to be the output of this generation
+* `trace_version` - Will set the trace version to be the provided value (To use the latest generations version instead, use `version`)
+* `trace_release` - Will set the trace release to be the provided value
+* `trace_metadata` - Will set the trace metadata to the provided value
+* `trace_user_id` - Will set the trace user id to the provided value
+
+#### Generation Specific Parameters
+
+* `generation_id` - Identifier for the generation, auto-generated by default
+* `generation_name` - Identifier for the generation, auto-generated by default
+* `prompt` - Langfuse prompt object used for the generation, defaults to None
+
+Any other key value pairs passed into the metadata not listed in the above spec for a `litellm` completion will be added as a metadata key value pair for the generation.
+
 ### Use LangChain ChatLiteLLM + Langfuse
 Pass `trace_user_id`, `session_id` in model_kwargs
 ```python
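To make the new continuation parameters concrete, here is a hedged sketch (not part of the diff) of continuing an existing trace and overwriting selected fields via `update_trace_keys`; the ids and values simply reuse the examples from the hunk above, and the model name is a placeholder.

```python
from litellm import completion

response = completion(
    model="gpt-3.5-turbo",  # placeholder model
    messages=[{"role": "user", "content": "second turn of the same conversation"}],
    metadata={
        "existing_trace_id": "trace-id22",                  # continue this Langfuse trace
        "trace_metadata": {"key": "updated_trace_value"},   # new value for the trace metadata
        "update_trace_keys": ["input", "output", "trace_metadata"],  # fields to overwrite on continuation
    },
)
print(response)
```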
@@ -535,7 +535,8 @@ print(response)

 | Model Name | Function Call |
 |----------------------|---------------------------------------------|
-| Titan Embeddings - G1 | `embedding(model="bedrock/amazon.titan-embed-text-v1", input=input)` |
+| Titan Embeddings V2 | `embedding(model="bedrock/amazon.titan-embed-text-v2:0", input=input)` |
+| Titan Embeddings - V1 | `embedding(model="bedrock/amazon.titan-embed-text-v1", input=input)` |
 | Cohere Embeddings - English | `embedding(model="bedrock/cohere.embed-english-v3", input=input)` |
 | Cohere Embeddings - Multilingual | `embedding(model="bedrock/cohere.embed-multilingual-v3", input=input)` |

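A hedged, minimal sketch (not part of the diff) of calling the new Titan Embeddings V2 row through litellm; the AWS credential values and region are placeholders.

```python
import os
from litellm import embedding

# placeholders: supply real AWS credentials and region for Bedrock
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = "us-west-2"

response = embedding(
    model="bedrock/amazon.titan-embed-text-v2:0",  # the new Titan Embeddings V2 entry from the table
    input=["good morning from litellm"],
)
print(response)
```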
docs/my-website/docs/providers/deepseek.md (new file, 54 lines)
@@ -0,0 +1,54 @@
# Deepseek
https://deepseek.com/

**We support ALL Deepseek models, just set `deepseek/` as a prefix when sending completion requests**

## API Key
```python
# env variable
os.environ['DEEPSEEK_API_KEY']
```

## Sample Usage
```python
from litellm import completion
import os

os.environ['DEEPSEEK_API_KEY'] = ""
response = completion(
    model="deepseek/deepseek-chat",
    messages=[
        {"role": "user", "content": "hello from litellm"}
    ],
)
print(response)
```

## Sample Usage - Streaming
```python
from litellm import completion
import os

os.environ['DEEPSEEK_API_KEY'] = ""
response = completion(
    model="deepseek/deepseek-chat",
    messages=[
        {"role": "user", "content": "hello from litellm"}
    ],
    stream=True
)

for chunk in response:
    print(chunk)
```


## Supported Models - ALL Deepseek Models Supported!
We support ALL Deepseek models, just set `deepseek/` as a prefix when sending completion requests

| Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| deepseek-chat | `completion(model="deepseek/deepseek-chat", messages)` |
| deepseek-coder | `completion(model="deepseek/deepseek-chat", messages)` |
@@ -21,6 +21,11 @@ This is done by adding the "huggingface/" prefix to `model`, example `completion
 <Tabs>
 <TabItem value="tgi" label="Text-generation-interface (TGI)">

+By default, LiteLLM will assume a huggingface call follows the TGI format.
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
 ```python
 import os
 from litellm import completion

@@ -40,9 +45,58 @@ response = completion(
 print(response)
 ```

+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Add models to your config.yaml
+
+```yaml
+model_list:
+  - model_name: wizard-coder
+    litellm_params:
+      model: huggingface/WizardLM/WizardCoder-Python-34B-V1.0
+      api_key: os.environ/HUGGINGFACE_API_KEY
+      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
+```
+
+2. Start the proxy
+
+```bash
+$ litellm --config /path/to/config.yaml --debug
+```
+
+3. Test it!
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data '{
+    "model": "wizard-coder",
+    "messages": [
+      {
+        "role": "user",
+        "content": "I like you!"
+      }
+    ],
+}'
+```
+
+</TabItem>
+</Tabs>
 </TabItem>
 <TabItem value="conv" label="Conversational-task (BlenderBot, etc.)">

+Append `conversational` to the model name
+
+e.g. `huggingface/conversational/<model-name>`
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
 ```python
 import os
 from litellm import completion

@@ -54,7 +108,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","

 # e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
 response = completion(
-    model="huggingface/facebook/blenderbot-400M-distill",
+    model="huggingface/conversational/facebook/blenderbot-400M-distill",
     messages=messages,
     api_base="https://my-endpoint.huggingface.cloud"
 )

@@ -62,7 +116,123 @@ response = completion(
 print(response)
 ```
 </TabItem>
-<TabItem value="none" label="Non TGI/Conversational-task LLMs">
+<TabItem value="proxy" label="PROXY">
+
+1. Add models to your config.yaml
+
+```yaml
+model_list:
+  - model_name: blenderbot
+    litellm_params:
+      model: huggingface/conversational/facebook/blenderbot-400M-distill
+      api_key: os.environ/HUGGINGFACE_API_KEY
+      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
+```
+
+2. Start the proxy
+
+```bash
+$ litellm --config /path/to/config.yaml --debug
+```
+
+3. Test it!
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data '{
+    "model": "blenderbot",
+    "messages": [
+      {
+        "role": "user",
+        "content": "I like you!"
+      }
+    ],
+}'
+```
+
+</TabItem>
+</Tabs>
+</TabItem>
+<TabItem value="classification" label="Text Classification">
+
+Append `text-classification` to the model name
+
+e.g. `huggingface/text-classification/<model-name>`
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+import os
+from litellm import completion
+
+# [OPTIONAL] set env var
+os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
+
+messages = [{ "content": "I like you, I love you!","role": "user"}]
+
+# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
+response = completion(
+    model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
+    messages=messages,
+    api_base="https://my-endpoint.endpoints.huggingface.cloud",
+)
+
+print(response)
+```
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Add models to your config.yaml
+
+```yaml
+model_list:
+  - model_name: bert-classifier
+    litellm_params:
+      model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
+      api_key: os.environ/HUGGINGFACE_API_KEY
+      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
+```
+
+2. Start the proxy
+
+```bash
+$ litellm --config /path/to/config.yaml --debug
+```
+
+3. Test it!
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data '{
+    "model": "bert-classifier",
+    "messages": [
+      {
+        "role": "user",
+        "content": "I like you!"
+      }
+    ],
+}'
+```
+
+</TabItem>
+</Tabs>
+</TabItem>
+<TabItem value="none" label="Text Generation (NOT TGI)">
+
+Append `text-generation` to the model name
+
+e.g. `huggingface/text-generation/<model-name>`
+
 ```python
 import os

@@ -75,7 +245,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","

 # e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints
 response = completion(
-    model="huggingface/roneneldan/TinyStories-3M",
+    model="huggingface/text-generation/roneneldan/TinyStories-3M",
     messages=messages,
     api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
 )
@@ -44,14 +44,14 @@ for chunk in response:
 ## Supported Models
 All models listed here https://docs.mistral.ai/platform/endpoints are supported. We actively maintain the list of models, pricing, token window, etc. [here](https://github.com/BerriAI/litellm/blob/c1b25538277206b9f00de5254d80d6a83bb19a29/model_prices_and_context_window.json).

 | Model Name | Function Call |
-|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| mistral-tiny | `completion(model="mistral/mistral-tiny", messages)` |
-| mistral-small | `completion(model="mistral/mistral-small", messages)` |
-| mistral-medium | `completion(model="mistral/mistral-medium", messages)` |
-| mistral-large-latest | `completion(model="mistral/mistral-large-latest", messages)` |
-| open-mixtral-8x22b | `completion(model="mistral/open-mixtral-8x22b", messages)` |
+|----------------|--------------------------------------------------------------|
+| Mistral Small | `completion(model="mistral/mistral-small-latest", messages)` |
+| Mistral Medium | `completion(model="mistral/mistral-medium-latest", messages)`|
+| Mistral Large | `completion(model="mistral/mistral-large-latest", messages)` |
+| Mistral 7B | `completion(model="mistral/open-mistral-7b", messages)` |
+| Mixtral 8x7B | `completion(model="mistral/open-mixtral-8x7b", messages)` |
+| Mixtral 8x22B | `completion(model="mistral/open-mixtral-8x22b", messages)` |

 ## Function Calling

@@ -116,6 +116,6 @@ All models listed here https://docs.mistral.ai/platform/endpoints are supported

 | Model Name | Function Call |
 |--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| mistral-embed | `embedding(model="mistral/mistral-embed", input)` |
+| Mistral Embeddings | `embedding(model="mistral/mistral-embed", input)` |

247
docs/my-website/docs/providers/predibase.md
Normal file
247
docs/my-website/docs/providers/predibase.md
Normal file
|
@ -0,0 +1,247 @@
|
||||||
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
|
# 🆕 Predibase
|
||||||
|
|
||||||
|
LiteLLM supports all models on Predibase
|
||||||
|
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
### API KEYS
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
os.environ["PREDIBASE_API_KEY"] = ""
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example Call
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import os
|
||||||
|
## set ENV variables
|
||||||
|
os.environ["PREDIBASE_API_KEY"] = "predibase key"
|
||||||
|
os.environ["PREDIBASE_TENANT_ID"] = "predibase tenant id"
|
||||||
|
|
||||||
|
# predibase llama-3 call
|
||||||
|
response = completion(
|
||||||
|
model="predibase/llama-3-8b-instruct",
|
||||||
|
messages = [{ "content": "Hello, how are you?","role": "user"}]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Add models to your config.yaml
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: llama-3
|
||||||
|
litellm_params:
|
||||||
|
model: predibase/llama-3-8b-instruct
|
||||||
|
api_key: os.environ/PREDIBASE_API_KEY
|
||||||
|
tenant_id: os.environ/PREDIBASE_TENANT_ID
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
2. Start the proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ litellm --config /path/to/config.yaml --debug
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Send Request to LiteLLM Proxy Server
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
|
||||||
|
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
client = openai.OpenAI(
|
||||||
|
api_key="sk-1234", # pass litellm proxy key, if you're using virtual keys
|
||||||
|
base_url="http://0.0.0.0:4000" # litellm-proxy-base url
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="llama-3",
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Be a good human!"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What do you know about earth?"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="curl" label="curl">
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"model": "llama-3",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Be a good human!"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What do you know about earth?"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
## Advanced Usage - Prompt Formatting
|
||||||
|
|
||||||
|
LiteLLM has prompt template mappings for all `meta-llama` llama3 instruct models. [**See Code**](https://github.com/BerriAI/litellm/blob/4f46b4c3975cd0f72b8c5acb2cb429d23580c18a/litellm/llms/prompt_templates/factory.py#L1360)
|
||||||
|
|
||||||
|
To apply a custom prompt template:
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
import litellm
from litellm import completion
|
||||||
|
|
||||||
|
import os
|
||||||
|
os.environ["PREDIBASE_API_KEY"] = ""
|
||||||
|
|
||||||
|
# Create your own custom prompt template
|
||||||
|
litellm.register_prompt_template(
|
||||||
|
model="togethercomputer/LLaMA-2-7B-32K",
|
||||||
|
initial_prompt_value="You are a good assistant", # [OPTIONAL]
|
||||||
|
roles={
|
||||||
|
"system": {
|
||||||
|
"pre_message": "[INST] <<SYS>>\n", # [OPTIONAL]
|
||||||
|
"post_message": "\n<</SYS>>\n [/INST]\n" # [OPTIONAL]
|
||||||
|
},
|
||||||
|
"user": {
|
||||||
|
"pre_message": "[INST] ", # [OPTIONAL]
|
||||||
|
"post_message": " [/INST]" # [OPTIONAL]
|
||||||
|
},
|
||||||
|
"assistant": {
|
||||||
|
"pre_message": "\n" # [OPTIONAL]
|
||||||
|
"post_message": "\n" # [OPTIONAL]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
final_prompt_value="Now answer as best you can:" # [OPTIONAL]
|
||||||
|
)
|
||||||
|
|
||||||
|
def predibase_custom_model():
|
||||||
|
model = "predibase/togethercomputer/LLaMA-2-7B-32K"
|
||||||
|
response = completion(model=model, messages=[{"role": "user", "content": "Hello, how are you?"}])
|
||||||
|
print(response['choices'][0]['message']['content'])
|
||||||
|
return response
|
||||||
|
|
||||||
|
predibase_custom_model()
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Model-specific parameters
|
||||||
|
model_list:
|
||||||
|
- model_name: mistral-7b # model alias
|
||||||
|
litellm_params: # actual params for litellm.completion()
|
||||||
|
model: "predibase/mistralai/Mistral-7B-Instruct-v0.1"
|
||||||
|
api_key: os.environ/PREDIBASE_API_KEY
|
||||||
|
initial_prompt_value: "\n"
|
||||||
|
roles: {"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}}
|
||||||
|
final_prompt_value: "\n"
|
||||||
|
bos_token: "<s>"
|
||||||
|
eos_token: "</s>"
|
||||||
|
max_tokens: 4096
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
## Passing additional params - max_tokens, temperature
|
||||||
|
See all litellm.completion supported params [here](https://docs.litellm.ai/docs/completion/input)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# !pip install litellm
|
||||||
|
from litellm import completion
|
||||||
|
import os
|
||||||
|
## set ENV variables
|
||||||
|
os.environ["PREDIBASE_API_KEY"] = "predibase key"
|
||||||
|
|
||||||
|
# predibase llama-3 call
|
||||||
|
response = completion(
|
||||||
|
model="predibase/llama3-8b-instruct",
|
||||||
|
messages = [{ "content": "Hello, how are you?","role": "user"}],
|
||||||
|
max_tokens=20,
|
||||||
|
temperature=0.5
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**proxy**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: llama-3
|
||||||
|
litellm_params:
|
||||||
|
model: predibase/llama-3-8b-instruct
|
||||||
|
api_key: os.environ/PREDIBASE_API_KEY
|
||||||
|
max_tokens: 20
|
||||||
|
temperature: 0.5
|
||||||
|
```
|
||||||
|
|
||||||
|
## Passing Predibase-specific params - adapter_id, adapter_source
|
||||||
|
Params that are [not part of the standard `litellm.completion()` input](https://docs.litellm.ai/docs/completion/input) but are supported by Predibase can be sent by passing them directly to `litellm.completion`
|
||||||
|
|
||||||
|
For example, `adapter_id` and `adapter_source` are Predibase-specific params - [See List](https://github.com/BerriAI/litellm/blob/8a35354dd6dbf4c2fcefcd6e877b980fcbd68c58/litellm/llms/predibase.py#L54)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# !pip install litellm
|
||||||
|
from litellm import completion
|
||||||
|
import os
|
||||||
|
## set ENV variables
|
||||||
|
os.environ["PREDIBASE_API_KEY"] = "predibase key"
|
||||||
|
|
||||||
|
# predibase llama3 call
|
||||||
|
response = completion(
|
||||||
|
model="predibase/llama-3-8b-instruct",
|
||||||
|
messages = [{ "content": "Hello, how are you?","role": "user"}],
|
||||||
|
adapter_id="my_repo/3",
|
||||||
|
adapter_soruce="pbase",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**proxy**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: llama-3
|
||||||
|
litellm_params:
|
||||||
|
model: predibase/llama-3-8b-instruct
|
||||||
|
api_key: os.environ/PREDIBASE_API_KEY
|
||||||
|
adapter_id: my_repo/3
|
||||||
|
adapter_source: pbase
|
||||||
|
```
|
95
docs/my-website/docs/providers/triton-inference-server.md
Normal file
95
docs/my-website/docs/providers/triton-inference-server.md
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
|
# Triton Inference Server
|
||||||
|
|
||||||
|
LiteLLM supports Embedding Models on Triton Inference Servers
|
||||||
|
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
|
||||||
|
### Example Call
|
||||||
|
|
||||||
|
Use the `triton/` prefix to route to the Triton Inference Server
|
||||||
|
```python
|
||||||
|
import litellm
|
||||||
|
import os
|
||||||
|
|
||||||
|
response = await litellm.aembedding(  # call from within an async function / event loop
|
||||||
|
model="triton/<your-triton-model>",
|
||||||
|
api_base="https://your-triton-api-base/triton/embeddings", # /embeddings endpoint you want litellm to call on your server
|
||||||
|
input=["good morning from litellm"],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Add models to your config.yaml
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: my-triton-model
|
||||||
|
litellm_params:
|
||||||
|
model: triton/<your-triton-model>"
|
||||||
|
api_base: https://your-triton-api-base/triton/embeddings
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
2. Start the proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ litellm --config /path/to/config.yaml --detailed_debug
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Send Request to LiteLLM Proxy Server
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
|
||||||
|
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
# set base_url to your proxy server
|
||||||
|
# set api_key to send to proxy server
|
||||||
|
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")
|
||||||
|
|
||||||
|
response = client.embeddings.create(
|
||||||
|
input=["hello from litellm"],
|
||||||
|
model="my-triton-model"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="curl" label="curl">
|
||||||
|
|
||||||
|
`--header` is optional, only required if you're using litellm proxy with Virtual Keys
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:4000/embeddings' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--data ' {
|
||||||
|
"model": "my-triton-model",
|
||||||
|
"input": ["write a litellm poem"]
|
||||||
|
}'
|
||||||
|
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
|
@ -477,6 +477,36 @@ print(response)
|
||||||
| code-gecko@latest| `completion('code-gecko@latest', messages)` |
|
| code-gecko@latest| `completion('code-gecko@latest', messages)` |
|
||||||
|
|
||||||
|
|
||||||
|
## Embedding Models
|
||||||
|
|
||||||
|
#### Usage - Embedding
|
||||||
|
```python
|
||||||
|
import litellm
|
||||||
|
from litellm import embedding
|
||||||
|
litellm.vertex_project = "hardy-device-38811" # Your Project ID
|
||||||
|
litellm.vertex_location = "us-central1" # proj location
|
||||||
|
|
||||||
|
response = embedding(
|
||||||
|
model="vertex_ai/textembedding-gecko",
|
||||||
|
input=["good morning from litellm"],
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Supported Embedding Models
|
||||||
|
All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a0249f630a6792d49dffc2c5d9b7/model_prices_and_context_window.json#L835) are supported
|
||||||
|
|
||||||
|
| Model Name | Function Call |
|
||||||
|
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||||
|
| textembedding-gecko | `embedding(model="vertex_ai/textembedding-gecko", input)` |
|
||||||
|
| textembedding-gecko-multilingual | `embedding(model="vertex_ai/textembedding-gecko-multilingual", input)` |
|
||||||
|
| textembedding-gecko-multilingual@001 | `embedding(model="vertex_ai/textembedding-gecko-multilingual@001", input)` |
|
||||||
|
| textembedding-gecko@001 | `embedding(model="vertex_ai/textembedding-gecko@001", input)` |
|
||||||
|
| textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` |
|
||||||
|
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
|
||||||
|
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
|
||||||
|
|
||||||
|
|
||||||
## Extra
|
## Extra
|
||||||
|
|
||||||
### Using `GOOGLE_APPLICATION_CREDENTIALS`
|
### Using `GOOGLE_APPLICATION_CREDENTIALS`
|
||||||
|
@ -520,6 +550,12 @@ def load_vertex_ai_credentials():
|
||||||
|
|
||||||
### Using GCP Service Account
|
### Using GCP Service Account
|
||||||
|
|
||||||
|
:::info
|
||||||
|
|
||||||
|
Trying to deploy LiteLLM on Google Cloud Run? Tutorial [here](https://docs.litellm.ai/docs/proxy/deploy#deploy-on-google-cloud-run)
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
1. Figure out the Service Account bound to the Google Cloud Run service
|
1. Figure out the Service Account bound to the Google Cloud Run service
|
||||||
|
|
||||||
<Image img={require('../../img/gcp_acc_1.png')} />
|
<Image img={require('../../img/gcp_acc_1.png')} />
|
||||||
|
|
83
docs/my-website/docs/proxy/customer_routing.md
Normal file
83
docs/my-website/docs/proxy/customer_routing.md
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
# Region-based Routing
|
||||||
|
|
||||||
|
Route specific customers to eu-only models.
|
||||||
|
|
||||||
|
By specifying 'allowed_model_region' for a customer, LiteLLM will filter out any models in a model group that are not in the allowed region (e.g. 'eu').
|
||||||
|
|
||||||
|
[**See Code**](https://github.com/BerriAI/litellm/blob/5eb12e30cc5faa73799ebc7e48fc86ebf449c879/litellm/router.py#L2938)
|
||||||
|
|
||||||
|
### 1. Create customer with region-specification
|
||||||
|
|
||||||
|
Use the litellm 'end-user' object for this.
|
||||||
|
|
||||||
|
End-users can be tracked / identified by passing the 'user' param to litellm in an OpenAI chat completion/embedding call.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST --location 'http://0.0.0.0:4000/end_user/new' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"user_id" : "ishaan-jaff-45",
|
||||||
|
"allowed_model_region": "eu", # 👈 SPECIFY ALLOWED REGION='eu'
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Add eu models to model-group
|
||||||
|
|
||||||
|
Add eu models to a model group. For azure models, litellm can automatically infer the region (no need to set it).
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: gpt-3.5-turbo
|
||||||
|
litellm_params:
|
||||||
|
model: azure/gpt-35-turbo-eu # 👈 EU azure model
|
||||||
|
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
|
||||||
|
api_key: os.environ/AZURE_EUROPE_API_KEY
|
||||||
|
- model_name: gpt-3.5-turbo
|
||||||
|
litellm_params:
|
||||||
|
model: azure/chatgpt-v-2
|
||||||
|
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
||||||
|
api_version: "2023-05-15"
|
||||||
|
api_key: os.environ/AZURE_API_KEY
|
||||||
|
|
||||||
|
router_settings:
|
||||||
|
enable_pre_call_checks: true # 👈 IMPORTANT
|
||||||
|
```
|
||||||
|
|
||||||
|
Start the proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
litellm --config /path/to/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Test it!
|
||||||
|
|
||||||
|
Make a simple chat completions call to the proxy. In the response headers, you should see the returned api base.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST --location 'http://localhost:4000/chat/completions' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--data '{
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "what is the meaning of the universe? 1234"
|
||||||
|
}],
|
||||||
|
"user": "ishaan-jaff-45" # 👈 USER ID
|
||||||
|
}
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected API Base in response headers
|
||||||
|
|
||||||
|
```
|
||||||
|
x-litellm-api-base: "https://my-endpoint-europe-berri-992.openai.azure.com/"
|
||||||
|
```
|
||||||
|
|
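If you want to check this header programmatically, here is one possible sketch using the OpenAI Python SDK's raw-response helper (the key, base url, and user id are placeholders; the header name comes from the example above):

```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

# use the raw-response helper so the proxy's response headers are accessible
raw = client.chat.completions.with_raw_response.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what is the meaning of the universe? 1234"}],
    user="ishaan-jaff-45",  # end-user created with allowed_model_region='eu'
)
print(raw.headers.get("x-litellm-api-base"))  # expect the eu azure endpoint
response = raw.parse()  # the regular ChatCompletion object
```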
||||||
|
### FAQ
|
||||||
|
|
||||||
|
**What happens if there are no available models for that region?**
|
||||||
|
|
||||||
|
Since the router filters out models not in the specified region, it returns an error to the user if no models in that region are available.
|
|
@ -915,39 +915,72 @@ Test Request
|
||||||
litellm --test
|
litellm --test
|
||||||
```
|
```
|
||||||
|
|
||||||
## Logging Proxy Input/Output Traceloop (OpenTelemetry)
|
## Logging Proxy Input/Output in OpenTelemetry format using Traceloop's OpenLLMetry
|
||||||
|
|
||||||
Traceloop allows you to log LLM Input/Output in the OpenTelemetry format
|
[OpenLLMetry](https://github.com/traceloop/openllmetry) _(built and maintained by Traceloop)_ is a set of extensions
|
||||||
|
built on top of [OpenTelemetry](https://opentelemetry.io/) that gives you complete observability over your LLM
|
||||||
|
application. Because it uses OpenTelemetry under the
|
||||||
|
hood, [it can be connected to various observability solutions](https://www.traceloop.com/docs/openllmetry/integrations/introduction)
|
||||||
|
like:
|
||||||
|
|
||||||
We will use the `--config` to set `litellm.success_callback = ["traceloop"]` this will log all successfull LLM calls to traceloop
|
* [Traceloop](https://www.traceloop.com/docs/openllmetry/integrations/traceloop)
|
||||||
|
* [Axiom](https://www.traceloop.com/docs/openllmetry/integrations/axiom)
|
||||||
|
* [Azure Application Insights](https://www.traceloop.com/docs/openllmetry/integrations/azure)
|
||||||
|
* [Datadog](https://www.traceloop.com/docs/openllmetry/integrations/datadog)
|
||||||
|
* [Dynatrace](https://www.traceloop.com/docs/openllmetry/integrations/dynatrace)
|
||||||
|
* [Grafana Tempo](https://www.traceloop.com/docs/openllmetry/integrations/grafana)
|
||||||
|
* [Honeycomb](https://www.traceloop.com/docs/openllmetry/integrations/honeycomb)
|
||||||
|
* [HyperDX](https://www.traceloop.com/docs/openllmetry/integrations/hyperdx)
|
||||||
|
* [Instana](https://www.traceloop.com/docs/openllmetry/integrations/instana)
|
||||||
|
* [New Relic](https://www.traceloop.com/docs/openllmetry/integrations/newrelic)
|
||||||
|
* [OpenTelemetry Collector](https://www.traceloop.com/docs/openllmetry/integrations/otel-collector)
|
||||||
|
* [Service Now Cloud Observability](https://www.traceloop.com/docs/openllmetry/integrations/service-now)
|
||||||
|
* [Sentry](https://www.traceloop.com/docs/openllmetry/integrations/sentry)
|
||||||
|
* [SigNoz](https://www.traceloop.com/docs/openllmetry/integrations/signoz)
|
||||||
|
* [Splunk](https://www.traceloop.com/docs/openllmetry/integrations/splunk)
|
||||||
|
|
||||||
**Step 1** Install traceloop-sdk and set Traceloop API key
|
We will use the `--config` to set `litellm.success_callback = ["traceloop"]` to achieve this, steps are listed below.
|
||||||
|
|
||||||
|
**Step 1:** Install the SDK
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
pip install traceloop-sdk -U
|
pip install traceloop-sdk
|
||||||
```
|
```
|
||||||
|
|
||||||
Traceloop outputs standard OpenTelemetry data that can be connected to your observability stack. Send standard OpenTelemetry from LiteLLM Proxy to [Traceloop](https://www.traceloop.com/docs/openllmetry/integrations/traceloop), [Dynatrace](https://www.traceloop.com/docs/openllmetry/integrations/dynatrace), [Datadog](https://www.traceloop.com/docs/openllmetry/integrations/datadog)
|
**Step 2:** Configure Environment Variable for trace exporting
|
||||||
, [New Relic](https://www.traceloop.com/docs/openllmetry/integrations/newrelic), [Honeycomb](https://www.traceloop.com/docs/openllmetry/integrations/honeycomb), [Grafana Tempo](https://www.traceloop.com/docs/openllmetry/integrations/grafana), [Splunk](https://www.traceloop.com/docs/openllmetry/integrations/splunk), [OpenTelemetry Collector](https://www.traceloop.com/docs/openllmetry/integrations/otel-collector)
|
|
||||||
|
You will need to configure where to export your traces. This is controlled via environment variables; for example, for Traceloop
|
||||||
|
you should use `TRACELOOP_API_KEY`, whereas for Datadog you use `TRACELOOP_BASE_URL`. For more details,
|
||||||
|
visit [the Integrations Catalog](https://www.traceloop.com/docs/openllmetry/integrations/introduction).
|
||||||
|
|
||||||
|
If you are using Datadog as the observability solution, then you can set `TRACELOOP_BASE_URL` as:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
TRACELOOP_BASE_URL=http://<datadog-agent-hostname>:4318
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
|
||||||
|
|
||||||
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
|
|
||||||
```yaml
|
```yaml
|
||||||
model_list:
|
model_list:
|
||||||
- model_name: gpt-3.5-turbo
|
- model_name: gpt-3.5-turbo
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: gpt-3.5-turbo
|
model: gpt-3.5-turbo
|
||||||
|
api_key: my-fake-key # replace api_key with actual key
|
||||||
litellm_settings:
|
litellm_settings:
|
||||||
success_callback: ["traceloop"]
|
success_callback: [ "traceloop" ]
|
||||||
```
|
```
|
||||||
|
|
||||||
**Step 3**: Start the proxy, make a test request
|
**Step 4**: Start the proxy, make a test request
|
||||||
|
|
||||||
Start proxy
|
Start proxy
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
litellm --config config.yaml --debug
|
litellm --config config.yaml --debug
|
||||||
```
|
```
|
||||||
|
|
||||||
Test Request
|
Test Request
|
||||||
|
|
||||||
```
|
```
|
||||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
--header 'Content-Type: application/json' \
|
--header 'Content-Type: application/json' \
|
||||||
|
|
|
@ -3,34 +3,38 @@ import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
# ⚡ Best Practices for Production
|
# ⚡ Best Practices for Production
|
||||||
|
|
||||||
Expected Performance in Production
|
## 1. Use this config.yaml
|
||||||
|
Use this config.yaml in production (with your own LLMs)
|
||||||
|
|
||||||
1 LiteLLM Uvicorn Worker on Kubernetes
|
|
||||||
|
|
||||||
| Description | Value |
|
|
||||||
|--------------|-------|
|
|
||||||
| Avg latency | `50ms` |
|
|
||||||
| Median latency | `51ms` |
|
|
||||||
| `/chat/completions` Requests/second | `35` |
|
|
||||||
| `/chat/completions` Requests/minute | `2100` |
|
|
||||||
| `/chat/completions` Requests/hour | `126K` |
|
|
||||||
|
|
||||||
|
|
||||||
## 1. Switch off Debug Logging
|
|
||||||
|
|
||||||
Remove `set_verbose: True` from your config.yaml
|
|
||||||
```yaml
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: fake-openai-endpoint
|
||||||
|
litellm_params:
|
||||||
|
model: openai/fake
|
||||||
|
api_key: fake-key
|
||||||
|
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||||
|
|
||||||
|
general_settings:
|
||||||
|
master_key: sk-1234 # enter your own master key, ensure it starts with 'sk-'
|
||||||
|
alerting: ["slack"] # Setup slack alerting - get alerts on LLM exceptions, Budget Alerts, Slow LLM Responses
|
||||||
|
proxy_batch_write_at: 60 # Batch write spend updates every 60s
|
||||||
|
|
||||||
litellm_settings:
|
litellm_settings:
|
||||||
set_verbose: True
|
set_verbose: False # Switch off Debug Logging, ensure your logs do not have any debugging on
|
||||||
```
|
```
|
||||||
|
|
||||||
You should only see the following level of details in logs on the proxy server
|
Set slack webhook url in your env
|
||||||
```shell
|
```shell
|
||||||
# INFO: 192.168.2.205:11774 - "POST /chat/completions HTTP/1.1" 200 OK
|
export SLACK_WEBHOOK_URL="https://hooks.slack.com/services/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"
|
||||||
# INFO: 192.168.2.205:34717 - "POST /chat/completions HTTP/1.1" 200 OK
|
|
||||||
# INFO: 192.168.2.205:29734 - "POST /chat/completions HTTP/1.1" 200 OK
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
:::info
|
||||||
|
|
||||||
|
Need help or want dedicated support? Talk to a founder [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat).
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
|
||||||
## 2. On Kubernetes - Use 1 Uvicorn worker [Suggested CMD]
|
## 2. On Kubernetes - Use 1 Uvicorn worker [Suggested CMD]
|
||||||
|
|
||||||
Use this Docker `CMD`. This will start the proxy with 1 Uvicorn Async Worker
|
Use this Docker `CMD`. This will start the proxy with 1 Uvicorn Async Worker
|
||||||
|
@ -40,21 +44,12 @@ Use this Docker `CMD`. This will start the proxy with 1 Uvicorn Async Worker
|
||||||
CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
|
CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
|
||||||
```
|
```
|
||||||
|
|
||||||
## 3. Batch write spend updates every 60s
|
|
||||||
|
|
||||||
The default proxy batch write is 10s. This is to make it easy to see spend when debugging locally.
|
## 3. Use Redis 'port','host', 'password'. NOT 'redis_url'
|
||||||
|
|
||||||
In production, we recommend using a longer interval period of 60s. This reduces the number of connections used to make DB writes.
|
If you decide to use Redis, DO NOT use 'redis_url'. We recommend using the redis port, host, and password params.
|
||||||
|
|
||||||
```yaml
|
`redis_url` is 80 RPS slower
|
||||||
general_settings:
|
|
||||||
master_key: sk-1234
|
|
||||||
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 4. use Redis 'port','host', 'password'. NOT 'redis_url'
|
|
||||||
|
|
||||||
When connecting to Redis use redis port, host, and password params. Not 'redis_url'. We've seen a 80 RPS difference between these 2 approaches when using the async redis client.
|
|
||||||
|
|
||||||
This is still something we're investigating. Keep track of it [here](https://github.com/BerriAI/litellm/issues/3188)
|
This is still something we're investigating. Keep track of it [here](https://github.com/BerriAI/litellm/issues/3188)
|
||||||
|
|
||||||
|
@ -69,103 +64,31 @@ router_settings:
|
||||||
redis_password: os.environ/REDIS_PASSWORD
|
redis_password: os.environ/REDIS_PASSWORD
|
||||||
```
|
```
|
||||||
|
|
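The same guidance applies when instantiating the router from Python; a minimal sketch, assuming the Router accepts `redis_host`/`redis_port`/`redis_password` keyword arguments (mirroring the `router_settings` keys above) and that the Redis credentials are set in your environment:

```python
import os
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": os.getenv("OPENAI_API_KEY")},
        }
    ],
    # prefer explicit host/port/password over a single redis_url
    redis_host=os.environ["REDIS_HOST"],
    redis_port=int(os.environ["REDIS_PORT"]),
    redis_password=os.environ["REDIS_PASSWORD"],
)
```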
||||||
## 5. Switch off resetting budgets
|
## Extras
|
||||||
|
### Expected Performance in Production
|
||||||
|
|
||||||
Add this to your config.yaml. (Only spend per Key, User and Team will be tracked - spend per API Call will not be written to the LiteLLM Database)
|
1 LiteLLM Uvicorn Worker on Kubernetes
|
||||||
```yaml
|
|
||||||
general_settings:
|
|
||||||
disable_reset_budget: true
|
|
||||||
```
|
|
||||||
|
|
||||||
## 6. Move spend logs to separate server (BETA)
|
| Description | Value |
|
||||||
|
|--------------|-------|
|
||||||
Writing each spend log to the db can slow down your proxy. In testing, we saw a 70% improvement in median response time by moving spend-log writes to a separate server.
|
| Avg latency | `50ms` |
|
||||||
|
| Median latency | `51ms` |
|
||||||
👉 [LiteLLM Spend Logs Server](https://github.com/BerriAI/litellm/tree/main/litellm-js/spend-logs)
|
| `/chat/completions` Requests/second | `35` |
|
||||||
|
| `/chat/completions` Requests/minute | `2100` |
|
||||||
|
| `/chat/completions` Requests/hour | `126K` |
|
||||||
|
|
||||||
|
|
||||||
**Spend Logs**
|
### Verifying Debugging logs are off
|
||||||
This is a log of the key, tokens, model, and latency for each call on the proxy.
|
|
||||||
|
|
||||||
[**Full Payload**](https://github.com/BerriAI/litellm/blob/8c9623a6bc4ad9da0a2dac64249a60ed8da719e8/litellm/proxy/utils.py#L1769)
|
You should only see the following level of details in logs on the proxy server
|
||||||
|
```shell
|
||||||
|
# INFO: 192.168.2.205:11774 - "POST /chat/completions HTTP/1.1" 200 OK
|
||||||
**1. Start the spend logs server**
|
# INFO: 192.168.2.205:34717 - "POST /chat/completions HTTP/1.1" 200 OK
|
||||||
|
# INFO: 192.168.2.205:29734 - "POST /chat/completions HTTP/1.1" 200 OK
|
||||||
```bash
|
|
||||||
docker run -p 3000:3000 \
|
|
||||||
-e DATABASE_URL="postgres://.." \
|
|
||||||
ghcr.io/berriai/litellm-spend_logs:main-latest
|
|
||||||
|
|
||||||
# RUNNING on http://0.0.0.0:3000
|
|
||||||
```
|
|
||||||
|
|
||||||
**2. Connect to proxy**
|
|
||||||
|
|
||||||
|
|
||||||
Example litellm_config.yaml
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
model_list:
|
|
||||||
- model_name: fake-openai-endpoint
|
|
||||||
litellm_params:
|
|
||||||
model: openai/my-fake-model
|
|
||||||
api_key: my-fake-key
|
|
||||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
|
||||||
|
|
||||||
general_settings:
|
|
||||||
master_key: sk-1234
|
|
||||||
proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds)
|
|
||||||
```
|
|
||||||
|
|
||||||
Add `SPEND_LOGS_URL` as an environment variable when starting the proxy
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker run \
|
|
||||||
-v $(pwd)/litellm_config.yaml:/app/config.yaml \
|
|
||||||
-e DATABASE_URL="postgresql://.." \
|
|
||||||
-e SPEND_LOGS_URL="http://host.docker.internal:3000" \ # 👈 KEY CHANGE
|
|
||||||
-p 4000:4000 \
|
|
||||||
ghcr.io/berriai/litellm:main-latest \
|
|
||||||
--config /app/config.yaml --detailed_debug
|
|
||||||
|
|
||||||
# Running on http://0.0.0.0:4000
|
|
||||||
```
|
|
||||||
|
|
||||||
**3. Test Proxy!**
|
|
||||||
|
|
||||||
|
|
||||||
```bash
|
|
||||||
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
|
|
||||||
--header 'Content-Type: application/json' \
|
|
||||||
--header 'Authorization: Bearer sk-1234' \
|
|
||||||
--data '{
|
|
||||||
"model": "fake-openai-endpoint",
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "Be helpful"},
|
|
||||||
{"role": "user", "content": "What do you know?"}
|
|
||||||
]
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
|
|
||||||
In your LiteLLM Spend Logs Server, you should see
|
|
||||||
|
|
||||||
**Expected Response**
|
|
||||||
|
|
||||||
```
|
|
||||||
Received and stored 1 logs. Total logs in memory: 1
|
|
||||||
...
|
|
||||||
Flushed 1 log to the DB.
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
### Machine Specification
|
### Machine Specifications to Deploy LiteLLM
|
||||||
|
|
||||||
A t2.micro should be sufficient to handle 1k logs / minute on this server.
|
|
||||||
|
|
||||||
This consumes at max 120MB, and <0.1 vCPU.
|
|
||||||
|
|
||||||
## Machine Specifications to Deploy LiteLLM
|
|
||||||
|
|
||||||
| Service | Spec | CPUs | Memory | Architecture | Version|
|
| Service | Spec | CPUs | Memory | Architecture | Version|
|
||||||
| --- | --- | --- | --- | --- | --- |
|
| --- | --- | --- | --- | --- | --- |
|
||||||
|
@ -173,7 +96,7 @@ This consumes at max 120MB, and <0.1 vCPU.
|
||||||
| Redis Cache | - | - | - | - | 7.0+ Redis Engine|
|
| Redis Cache | - | - | - | - | 7.0+ Redis Engine|
|
||||||
|
|
||||||
|
|
||||||
## Reference Kubernetes Deployment YAML
|
### Reference Kubernetes Deployment YAML
|
||||||
|
|
||||||
Reference Kubernetes `deployment.yaml` that was load tested by us
|
Reference Kubernetes `deployment.yaml` that was load tested by us
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ This is a new feature, and subject to changes based on feedback.
|
||||||
### Step 1. Setup Proxy
|
### Step 1. Setup Proxy
|
||||||
|
|
||||||
- `JWT_PUBLIC_KEY_URL`: This is the public keys endpoint of your OpenID provider. Typically it's `{openid-provider-base-url}/.well-known/openid-configuration/jwks`. For Keycloak it's `{keycloak_base_url}/realms/{your-realm}/protocol/openid-connect/certs`.
|
- `JWT_PUBLIC_KEY_URL`: This is the public keys endpoint of your OpenID provider. Typically it's `{openid-provider-base-url}/.well-known/openid-configuration/jwks`. For Keycloak it's `{keycloak_base_url}/realms/{your-realm}/protocol/openid-connect/certs`.
|
||||||
|
- `JWT_AUDIENCE`: This is the audience used for decoding the JWT. If not set, the decode step will not verify the audience.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export JWT_PUBLIC_KEY_URL="" # "https://demo.duendesoftware.com/.well-known/openid-configuration/jwks"
|
export JWT_PUBLIC_KEY_URL="" # "https://demo.duendesoftware.com/.well-known/openid-configuration/jwks"
|
||||||
|
|
|
@ -12,8 +12,8 @@ Requirements:
|
||||||
|
|
||||||
You can set budgets at 3 levels:
|
You can set budgets at 3 levels:
|
||||||
- For the proxy
|
- For the proxy
|
||||||
- For a user
|
- For an internal user
|
||||||
- For a 'user' passed to `/chat/completions`, `/embeddings` etc
|
- For an end-user
|
||||||
- For a key
|
- For a key
|
||||||
- For a key (model specific budgets)
|
- For a key (model specific budgets)
|
||||||
|
|
||||||
|
@ -58,7 +58,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
</TabItem>
|
</TabItem>
|
||||||
<TabItem value="per-user" label="For User">
|
<TabItem value="per-user" label="For Internal User">
|
||||||
|
|
||||||
Apply a budget across multiple keys.
|
Apply a budget across multiple keys.
|
||||||
|
|
||||||
|
@ -165,12 +165,12 @@ curl --location 'http://localhost:4000/team/new' \
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
</TabItem>
|
</TabItem>
|
||||||
<TabItem value="per-user-chat" label="For 'user' passed to /chat/completions">
|
<TabItem value="per-user-chat" label="For End User">
|
||||||
|
|
||||||
Use this to budget `user` passed to `/chat/completions`, **without needing to create a key for every user**
|
Use this to budget `user` passed to `/chat/completions`, **without needing to create a key for every user**
|
||||||
|
|
||||||
**Step 1. Modify config.yaml**
|
**Step 1. Modify config.yaml**
|
||||||
Define `litellm.max_user_budget`
|
Define `litellm.max_end_user_budget`
|
||||||
```yaml
|
```yaml
|
||||||
general_settings:
|
general_settings:
|
||||||
master_key: sk-1234
|
master_key: sk-1234
|
||||||
|
@ -328,7 +328,7 @@ You can set:
|
||||||
- max parallel requests
|
- max parallel requests
|
||||||
|
|
||||||
<Tabs>
|
<Tabs>
|
||||||
<TabItem value="per-user" label="Per User">
|
<TabItem value="per-user" label="Per Internal User">
|
||||||
|
|
||||||
Use `/user/new`, to persist rate limits across multiple keys.
|
Use `/user/new`, to persist rate limits across multiple keys.
|
||||||
|
|
||||||
|
@ -408,7 +408,7 @@ curl --location 'http://localhost:4000/user/new' \
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Create new keys for existing user
|
## Create new keys for existing internal user
|
||||||
|
|
||||||
Just include user_id in the `/key/generate` request.
|
Just include user_id in the `/key/generate` request.
|
||||||
|
|
||||||
|
|
|
@ -96,7 +96,7 @@ print(response)
|
||||||
- `router.aimage_generation()` - async image generation calls
|
- `router.aimage_generation()` - async image generation calls
|
||||||
|
|
||||||
## Advanced - Routing Strategies
|
## Advanced - Routing Strategies
|
||||||
#### Routing Strategies - Weighted Pick, Rate Limit Aware, Least Busy, Latency Based
|
#### Routing Strategies - Weighted Pick, Rate Limit Aware, Least Busy, Latency Based, Cost Based
|
||||||
|
|
||||||
Router provides 4 strategies for routing your calls across multiple deployments:
|
Router provides 4 strategies for routing your calls across multiple deployments:
|
||||||
|
|
||||||
|
@ -467,6 +467,101 @@ async def router_acompletion():
|
||||||
asyncio.run(router_acompletion())
|
asyncio.run(router_acompletion())
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="lowest-cost" label="Lowest Cost Routing (Async)">
|
||||||
|
|
||||||
|
Picks a deployment based on the lowest cost
|
||||||
|
|
||||||
|
How this works:
|
||||||
|
- Get all healthy deployments
|
||||||
|
- Select all deployments that are under their provided `rpm/tpm` limits
|
||||||
|
- For each deployment check if `litellm_param["model"]` exists in [`litellm_model_cost_map`](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
|
||||||
|
- if deployment does not exist in `litellm_model_cost_map` -> use deployment_cost= `$1`
|
||||||
|
- Select deployment with lowest cost
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import Router
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
model_list = [
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {"model": "gpt-4"},
|
||||||
|
"model_info": {"id": "openai-gpt-4"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {"model": "groq/llama3-8b-8192"},
|
||||||
|
"model_info": {"id": "groq-llama"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# init router
|
||||||
|
router = Router(model_list=model_list, routing_strategy="cost-based-routing")
|
||||||
|
async def router_acompletion():
|
||||||
|
response = await router.acompletion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[{"role": "user", "content": "Hey, how's it going?"}]
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
print(response._hidden_params["model_id"]) # expect groq-llama, since groq/llama has lowest cost
|
||||||
|
return response
|
||||||
|
|
||||||
|
asyncio.run(router_acompletion())
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
#### Using Custom Input/Output pricing
|
||||||
|
|
||||||
|
Set `litellm_params["input_cost_per_token"]` and `litellm_params["output_cost_per_token"]` for using custom pricing when routing
|
||||||
|
|
||||||
|
```python
|
||||||
|
model_list = [
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"input_cost_per_token": 0.00003,
|
||||||
|
"output_cost_per_token": 0.00003,
|
||||||
|
},
|
||||||
|
"model_info": {"id": "chatgpt-v-experimental"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/chatgpt-v-1",
|
||||||
|
"input_cost_per_token": 0.000000001,
|
||||||
|
"output_cost_per_token": 0.00000001,
|
||||||
|
},
|
||||||
|
"model_info": {"id": "chatgpt-v-1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/chatgpt-v-5",
|
||||||
|
"input_cost_per_token": 10,
|
||||||
|
"output_cost_per_token": 12,
|
||||||
|
},
|
||||||
|
"model_info": {"id": "chatgpt-v-5"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
# init router
|
||||||
|
router = Router(model_list=model_list, routing_strategy="cost-based-routing")
|
||||||
|
async def router_acompletion():
|
||||||
|
response = await router.acompletion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[{"role": "user", "content": "Hey, how's it going?"}]
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
print(response._hidden_params["model_id"]) # expect chatgpt-v-1, since chatgpt-v-1 has lowest cost
|
||||||
|
return response
|
||||||
|
|
||||||
|
asyncio.run(router_acompletion())
|
||||||
|
```
|
||||||
|
|
||||||
</TabItem>
|
</TabItem>
|
||||||
|
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
@ -616,6 +711,57 @@ response = router.completion(model="gpt-3.5-turbo", messages=messages)
|
||||||
print(f"response: {response}")
|
print(f"response: {response}")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Retries based on Error Type
|
||||||
|
|
||||||
|
Use `RetryPolicy` if you want to set `num_retries` based on the exception received
|
||||||
|
|
||||||
|
Example:
|
||||||
|
- 4 retries for `ContentPolicyViolationError`
|
||||||
|
- 0 retries for `RateLimitErrors`
|
||||||
|
|
||||||
|
Example Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm.router import RetryPolicy
|
||||||
|
retry_policy = RetryPolicy(
|
||||||
|
ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors
|
||||||
|
AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrors
|
||||||
|
BadRequestErrorRetries=1,
|
||||||
|
TimeoutErrorRetries=2,
|
||||||
|
RateLimitErrorRetries=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = litellm.Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo", # openai model name
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": os.getenv("AZURE_API_KEY"),
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "bad-model", # openai model name
|
||||||
|
"litellm_params": { # params for litellm completion/embedding call
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": "bad-key",
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
retry_policy=retry_policy,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await router.acompletion(
|
||||||
|
model="gpt-3.5-turbo",  # or "bad-model" to exercise the retry policy
|
||||||
|
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
### Fallbacks
|
### Fallbacks
|
||||||
|
|
||||||
If a call fails after num_retries, fall back to another model group.
|
If a call fails after num_retries, fall back to another model group.
|
||||||
|
@ -940,6 +1086,46 @@ async def test_acompletion_caching_on_router_caching_groups():
|
||||||
asyncio.run(test_acompletion_caching_on_router_caching_groups())
|
asyncio.run(test_acompletion_caching_on_router_caching_groups())
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Alerting 🚨
|
||||||
|
|
||||||
|
Send alerts to slack / your webhook url for the following events
|
||||||
|
- LLM API Exceptions
|
||||||
|
- Slow LLM Responses
|
||||||
|
|
||||||
|
Get a slack webhook url from https://api.slack.com/messaging/webhooks
|
||||||
|
|
||||||
|
#### Usage
|
||||||
|
Initialize an `AlertingConfig` and pass it to `litellm.Router`. The following code will trigger an alert because `api_key="bad_key"` is invalid
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm.router import AlertingConfig
|
||||||
|
import litellm
|
||||||
|
import os
|
||||||
|
|
||||||
|
router = litellm.Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"api_key": "bad_key",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
alerting_config= AlertingConfig(
|
||||||
|
alerting_threshold=10, # threshold for slow / hanging llm responses (in seconds). Defaults to 300 seconds
|
||||||
|
webhook_url= os.getenv("SLACK_WEBHOOK_URL") # webhook you want to send alerts to
|
||||||
|
),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await router.acompletion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
||||||
## Track cost for Azure Deployments
|
## Track cost for Azure Deployments
|
||||||
|
|
||||||
**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking
|
**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking
|
||||||
|
@ -1108,6 +1294,7 @@ def __init__(
|
||||||
"least-busy",
|
"least-busy",
|
||||||
"usage-based-routing",
|
"usage-based-routing",
|
||||||
"latency-based-routing",
|
"latency-based-routing",
|
||||||
|
"cost-based-routing",
|
||||||
] = "simple-shuffle",
|
] = "simple-shuffle",
|
||||||
|
|
||||||
## DEBUGGING ##
|
## DEBUGGING ##
|
||||||
|
|
|
@ -50,6 +50,7 @@ const sidebars = {
|
||||||
items: ["proxy/logging", "proxy/streaming_logging"],
|
items: ["proxy/logging", "proxy/streaming_logging"],
|
||||||
},
|
},
|
||||||
"proxy/team_based_routing",
|
"proxy/team_based_routing",
|
||||||
|
"proxy/customer_routing",
|
||||||
"proxy/ui",
|
"proxy/ui",
|
||||||
"proxy/cost_tracking",
|
"proxy/cost_tracking",
|
||||||
"proxy/token_auth",
|
"proxy/token_auth",
|
||||||
|
@ -131,9 +132,13 @@ const sidebars = {
|
||||||
"providers/cohere",
|
"providers/cohere",
|
||||||
"providers/anyscale",
|
"providers/anyscale",
|
||||||
"providers/huggingface",
|
"providers/huggingface",
|
||||||
|
"providers/watsonx",
|
||||||
|
"providers/predibase",
|
||||||
|
"providers/triton-inference-server",
|
||||||
"providers/ollama",
|
"providers/ollama",
|
||||||
"providers/perplexity",
|
"providers/perplexity",
|
||||||
"providers/groq",
|
"providers/groq",
|
||||||
|
"providers/deepseek",
|
||||||
"providers/fireworks_ai",
|
"providers/fireworks_ai",
|
||||||
"providers/vllm",
|
"providers/vllm",
|
||||||
"providers/xinference",
|
"providers/xinference",
|
||||||
|
@ -149,7 +154,7 @@ const sidebars = {
|
||||||
"providers/openrouter",
|
"providers/openrouter",
|
||||||
"providers/custom_openai_proxy",
|
"providers/custom_openai_proxy",
|
||||||
"providers/petals",
|
"providers/petals",
|
||||||
"providers/watsonx",
|
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
"proxy/custom_pricing",
|
"proxy/custom_pricing",
|
||||||
|
|
|
@ -291,7 +291,7 @@ def _create_clickhouse_aggregate_tables(client=None, table_names=[]):
|
||||||
|
|
||||||
|
|
||||||
def _forecast_daily_cost(data: list):
|
def _forecast_daily_cost(data: list):
|
||||||
import requests
|
import requests # type: ignore
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
if len(data) == 0:
|
if len(data) == 0:
|
||||||
|
|
108
index.yaml
Normal file
108
index.yaml
Normal file
|
@ -0,0 +1,108 @@
|
||||||
|
apiVersion: v1
|
||||||
|
entries:
|
||||||
|
litellm-helm:
|
||||||
|
- apiVersion: v2
|
||||||
|
appVersion: v1.35.38
|
||||||
|
created: "2024-05-06T10:22:24.384392-07:00"
|
||||||
|
dependencies:
|
||||||
|
- condition: db.deployStandalone
|
||||||
|
name: postgresql
|
||||||
|
repository: oci://registry-1.docker.io/bitnamicharts
|
||||||
|
version: '>=13.3.0'
|
||||||
|
- condition: redis.enabled
|
||||||
|
name: redis
|
||||||
|
repository: oci://registry-1.docker.io/bitnamicharts
|
||||||
|
version: '>=18.0.0'
|
||||||
|
description: Call all LLM APIs using the OpenAI format
|
||||||
|
digest: 60f0cfe9e7c1087437cb35f6fb7c43c3ab2be557b6d3aec8295381eb0dfa760f
|
||||||
|
name: litellm-helm
|
||||||
|
type: application
|
||||||
|
urls:
|
||||||
|
- litellm-helm-0.2.0.tgz
|
||||||
|
version: 0.2.0
|
||||||
|
postgresql:
|
||||||
|
- annotations:
|
||||||
|
category: Database
|
||||||
|
images: |
|
||||||
|
- name: os-shell
|
||||||
|
image: docker.io/bitnami/os-shell:12-debian-12-r16
|
||||||
|
- name: postgres-exporter
|
||||||
|
image: docker.io/bitnami/postgres-exporter:0.15.0-debian-12-r14
|
||||||
|
- name: postgresql
|
||||||
|
image: docker.io/bitnami/postgresql:16.2.0-debian-12-r6
|
||||||
|
licenses: Apache-2.0
|
||||||
|
apiVersion: v2
|
||||||
|
appVersion: 16.2.0
|
||||||
|
created: "2024-05-06T10:22:24.387717-07:00"
|
||||||
|
dependencies:
|
||||||
|
- name: common
|
||||||
|
repository: oci://registry-1.docker.io/bitnamicharts
|
||||||
|
tags:
|
||||||
|
- bitnami-common
|
||||||
|
version: 2.x.x
|
||||||
|
description: PostgreSQL (Postgres) is an open source object-relational database
|
||||||
|
known for reliability and data integrity. ACID-compliant, it supports foreign
|
||||||
|
keys, joins, views, triggers and stored procedures.
|
||||||
|
digest: 3c8125526b06833df32e2f626db34aeaedb29d38f03d15349db6604027d4a167
|
||||||
|
home: https://bitnami.com
|
||||||
|
icon: https://bitnami.com/assets/stacks/postgresql/img/postgresql-stack-220x234.png
|
||||||
|
keywords:
|
||||||
|
- postgresql
|
||||||
|
- postgres
|
||||||
|
- database
|
||||||
|
- sql
|
||||||
|
- replication
|
||||||
|
- cluster
|
||||||
|
maintainers:
|
||||||
|
- name: VMware, Inc.
|
||||||
|
url: https://github.com/bitnami/charts
|
||||||
|
name: postgresql
|
||||||
|
sources:
|
||||||
|
- https://github.com/bitnami/charts/tree/main/bitnami/postgresql
|
||||||
|
urls:
|
||||||
|
- charts/postgresql-14.3.1.tgz
|
||||||
|
version: 14.3.1
|
||||||
|
redis:
|
||||||
|
- annotations:
|
||||||
|
category: Database
|
||||||
|
images: |
|
||||||
|
- name: kubectl
|
||||||
|
image: docker.io/bitnami/kubectl:1.29.2-debian-12-r3
|
||||||
|
- name: os-shell
|
||||||
|
image: docker.io/bitnami/os-shell:12-debian-12-r16
|
||||||
|
- name: redis
|
||||||
|
image: docker.io/bitnami/redis:7.2.4-debian-12-r9
|
||||||
|
- name: redis-exporter
|
||||||
|
image: docker.io/bitnami/redis-exporter:1.58.0-debian-12-r4
|
||||||
|
- name: redis-sentinel
|
||||||
|
image: docker.io/bitnami/redis-sentinel:7.2.4-debian-12-r7
|
||||||
|
licenses: Apache-2.0
|
||||||
|
apiVersion: v2
|
||||||
|
appVersion: 7.2.4
|
||||||
|
created: "2024-05-06T10:22:24.391903-07:00"
|
||||||
|
dependencies:
|
||||||
|
- name: common
|
||||||
|
repository: oci://registry-1.docker.io/bitnamicharts
|
||||||
|
tags:
|
||||||
|
- bitnami-common
|
||||||
|
version: 2.x.x
|
||||||
|
description: Redis(R) is an open source, advanced key-value store. It is often
|
||||||
|
referred to as a data structure server since keys can contain strings, hashes,
|
||||||
|
lists, sets and sorted sets.
|
||||||
|
digest: b2fa1835f673a18002ca864c54fadac3c33789b26f6c5e58e2851b0b14a8f984
|
||||||
|
home: https://bitnami.com
|
||||||
|
icon: https://bitnami.com/assets/stacks/redis/img/redis-stack-220x234.png
|
||||||
|
keywords:
|
||||||
|
- redis
|
||||||
|
- keyvalue
|
||||||
|
- database
|
||||||
|
maintainers:
|
||||||
|
- name: VMware, Inc.
|
||||||
|
url: https://github.com/bitnami/charts
|
||||||
|
name: redis
|
||||||
|
sources:
|
||||||
|
- https://github.com/bitnami/charts/tree/main/bitnami/redis
|
||||||
|
urls:
|
||||||
|
- charts/redis-18.19.1.tgz
|
||||||
|
version: 18.19.1
|
||||||
|
generated: "2024-05-06T10:22:24.375026-07:00"
|
BIN
litellm-helm-0.2.0.tgz
Normal file
BIN
litellm-helm-0.2.0.tgz
Normal file
Binary file not shown.
|
@ -1,3 +1,7 @@
|
||||||
|
### Hide pydantic namespace conflict warnings globally ###
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
|
||||||
### INIT VARIABLES ###
|
### INIT VARIABLES ###
|
||||||
import threading, requests, os
|
import threading, requests, os
|
||||||
from typing import Callable, List, Optional, Dict, Union, Any, Literal
|
from typing import Callable, List, Optional, Dict, Union, Any, Literal
|
||||||
|
@ -71,9 +75,11 @@ maritalk_key: Optional[str] = None
|
||||||
ai21_key: Optional[str] = None
|
ai21_key: Optional[str] = None
|
||||||
ollama_key: Optional[str] = None
|
ollama_key: Optional[str] = None
|
||||||
openrouter_key: Optional[str] = None
|
openrouter_key: Optional[str] = None
|
||||||
|
predibase_key: Optional[str] = None
|
||||||
huggingface_key: Optional[str] = None
|
huggingface_key: Optional[str] = None
|
||||||
vertex_project: Optional[str] = None
|
vertex_project: Optional[str] = None
|
||||||
vertex_location: Optional[str] = None
|
vertex_location: Optional[str] = None
|
||||||
|
predibase_tenant_id: Optional[str] = None
|
||||||
togetherai_api_key: Optional[str] = None
|
togetherai_api_key: Optional[str] = None
|
||||||
cloudflare_api_key: Optional[str] = None
|
cloudflare_api_key: Optional[str] = None
|
||||||
baseten_key: Optional[str] = None
|
baseten_key: Optional[str] = None
|
||||||
|
@ -361,6 +367,7 @@ openai_compatible_endpoints: List = [
|
||||||
"api.deepinfra.com/v1/openai",
|
"api.deepinfra.com/v1/openai",
|
||||||
"api.mistral.ai/v1",
|
"api.mistral.ai/v1",
|
||||||
"api.groq.com/openai/v1",
|
"api.groq.com/openai/v1",
|
||||||
|
"api.deepseek.com/v1",
|
||||||
"api.together.xyz/v1",
|
"api.together.xyz/v1",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -369,6 +376,7 @@ openai_compatible_providers: List = [
|
||||||
"anyscale",
|
"anyscale",
|
||||||
"mistral",
|
"mistral",
|
||||||
"groq",
|
"groq",
|
||||||
|
"deepseek",
|
||||||
"deepinfra",
|
"deepinfra",
|
||||||
"perplexity",
|
"perplexity",
|
||||||
"xinference",
|
"xinference",
|
||||||
|
@ -523,12 +531,15 @@ provider_list: List = [
|
||||||
"anyscale",
|
"anyscale",
|
||||||
"mistral",
|
"mistral",
|
||||||
"groq",
|
"groq",
|
||||||
|
"deepseek",
|
||||||
"maritalk",
|
"maritalk",
|
||||||
"voyage",
|
"voyage",
|
||||||
"cloudflare",
|
"cloudflare",
|
||||||
"xinference",
|
"xinference",
|
||||||
"fireworks_ai",
|
"fireworks_ai",
|
||||||
"watsonx",
|
"watsonx",
|
||||||
|
"triton",
|
||||||
|
"predibase",
|
||||||
"custom", # custom apis
|
"custom", # custom apis
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -605,7 +616,6 @@ all_embedding_models = (
|
||||||
####### IMAGE GENERATION MODELS ###################
|
####### IMAGE GENERATION MODELS ###################
|
||||||
openai_image_generation_models = ["dall-e-2", "dall-e-3"]
|
openai_image_generation_models = ["dall-e-2", "dall-e-3"]
|
||||||
|
|
||||||
|
|
||||||
from .timeout import timeout
|
from .timeout import timeout
|
||||||
from .utils import (
|
from .utils import (
|
||||||
client,
|
client,
|
||||||
|
@ -613,6 +623,8 @@ from .utils import (
|
||||||
get_optional_params,
|
get_optional_params,
|
||||||
modify_integration,
|
modify_integration,
|
||||||
token_counter,
|
token_counter,
|
||||||
|
create_pretrained_tokenizer,
|
||||||
|
create_tokenizer,
|
||||||
cost_per_token,
|
cost_per_token,
|
||||||
completion_cost,
|
completion_cost,
|
||||||
supports_function_calling,
|
supports_function_calling,
|
||||||
|
@ -636,9 +648,11 @@ from .utils import (
|
||||||
get_secret,
|
get_secret,
|
||||||
get_supported_openai_params,
|
get_supported_openai_params,
|
||||||
get_api_base,
|
get_api_base,
|
||||||
|
get_first_chars_messages,
|
||||||
)
|
)
|
||||||
from .llms.huggingface_restapi import HuggingfaceConfig
|
from .llms.huggingface_restapi import HuggingfaceConfig
|
||||||
from .llms.anthropic import AnthropicConfig
|
from .llms.anthropic import AnthropicConfig
|
||||||
|
from .llms.predibase import PredibaseConfig
|
||||||
from .llms.anthropic_text import AnthropicTextConfig
|
from .llms.anthropic_text import AnthropicTextConfig
|
||||||
from .llms.replicate import ReplicateConfig
|
from .llms.replicate import ReplicateConfig
|
||||||
from .llms.cohere import CohereConfig
|
from .llms.cohere import CohereConfig
|
||||||
|
@ -692,3 +706,4 @@ from .exceptions import (
|
||||||
from .budget_manager import BudgetManager
|
from .budget_manager import BudgetManager
|
||||||
from .proxy.proxy_cli import run_server
|
from .proxy.proxy_cli import run_server
|
||||||
from .router import Router
|
from .router import Router
|
||||||
|
from .assistants.main import *
|
||||||
|
|
|
@ -10,8 +10,8 @@
|
||||||
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
|
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
|
||||||
import os
|
import os
|
||||||
import inspect
|
import inspect
|
||||||
import redis, litellm
|
import redis, litellm # type: ignore
|
||||||
import redis.asyncio as async_redis
|
import redis.asyncio as async_redis # type: ignore
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
|
495
litellm/assistants/main.py
Normal file
495
litellm/assistants/main.py
Normal file
|
@ -0,0 +1,495 @@
# What is this?
## Main file for assistants API logic
from typing import Iterable
import os
import litellm
from openai import OpenAI
from litellm import client
from litellm.utils import supports_httpx_timeout
from ..llms.openai import OpenAIAssistantsAPI
from ..types.llms.openai import *
from ..types.router import *

####### ENVIRONMENT VARIABLES ###################
openai_assistants_api = OpenAIAssistantsAPI()

### ASSISTANTS ###


def get_assistants(
    custom_llm_provider: Literal["openai"],
    client: Optional[OpenAI] = None,
    **kwargs,
) -> SyncCursorPage[Assistant]:
    optional_params = GenericLiteLLMParams(**kwargs)

    ### TIMEOUT LOGIC ###
    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
    # set timeout for 10 minutes by default

    if (
        timeout is not None
        and isinstance(timeout, httpx.Timeout)
        and supports_httpx_timeout(custom_llm_provider) == False
    ):
        read_timeout = timeout.read or 600
        timeout = read_timeout  # default 10 min timeout
    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
        timeout = float(timeout)  # type: ignore
    elif timeout is None:
        timeout = 600.0

    response: Optional[SyncCursorPage[Assistant]] = None
    if custom_llm_provider == "openai":
        api_base = (
            optional_params.api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            or litellm.api_base
            or os.getenv("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )
        organization = (
            optional_params.organization
            or litellm.organization
            or os.getenv("OPENAI_ORGANIZATION", None)
            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
        )
        # set API KEY
        api_key = (
            optional_params.api_key
            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
            or litellm.openai_key
            or os.getenv("OPENAI_API_KEY")
        )
        response = openai_assistants_api.get_assistants(
            api_base=api_base,
            api_key=api_key,
            timeout=timeout,
            max_retries=optional_params.max_retries,
            organization=organization,
            client=client,
        )
    else:
        raise litellm.exceptions.BadRequestError(
            message="LiteLLM doesn't support {} for 'get_assistants'. Only 'openai' is supported.".format(
                custom_llm_provider
            ),
            model="n/a",
            llm_provider=custom_llm_provider,
            response=httpx.Response(
                status_code=400,
                content="Unsupported provider",
                request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
            ),
        )
    return response


### THREADS ###


def create_thread(
    custom_llm_provider: Literal["openai"],
    messages: Optional[Iterable[OpenAICreateThreadParamsMessage]] = None,
    metadata: Optional[dict] = None,
    tool_resources: Optional[OpenAICreateThreadParamsToolResources] = None,
    client: Optional[OpenAI] = None,
    **kwargs,
) -> Thread:
    """
    - get the llm provider
    - if openai - route it there
    - pass through relevant params

    ```
    from litellm import create_thread

    create_thread(
        custom_llm_provider="openai",
        ### OPTIONAL ###
        messages = [{
            "role": "user",
            "content": "Hello, what is AI?"
            },
            {
            "role": "user",
            "content": "How does AI work? Explain it in simple terms."
        }]
    )
    ```
    """
    optional_params = GenericLiteLLMParams(**kwargs)

    ### TIMEOUT LOGIC ###
    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
    # set timeout for 10 minutes by default

    if (
        timeout is not None
        and isinstance(timeout, httpx.Timeout)
        and supports_httpx_timeout(custom_llm_provider) == False
    ):
        read_timeout = timeout.read or 600
        timeout = read_timeout  # default 10 min timeout
    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
        timeout = float(timeout)  # type: ignore
    elif timeout is None:
        timeout = 600.0

    response: Optional[Thread] = None
    if custom_llm_provider == "openai":
        api_base = (
            optional_params.api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            or litellm.api_base
            or os.getenv("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )
        organization = (
            optional_params.organization
            or litellm.organization
            or os.getenv("OPENAI_ORGANIZATION", None)
            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
        )
        # set API KEY
        api_key = (
            optional_params.api_key
            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
            or litellm.openai_key
            or os.getenv("OPENAI_API_KEY")
        )
        response = openai_assistants_api.create_thread(
            messages=messages,
            metadata=metadata,
            api_base=api_base,
            api_key=api_key,
            timeout=timeout,
            max_retries=optional_params.max_retries,
            organization=organization,
            client=client,
        )
    else:
        raise litellm.exceptions.BadRequestError(
            message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
                custom_llm_provider
            ),
            model="n/a",
            llm_provider=custom_llm_provider,
            response=httpx.Response(
                status_code=400,
                content="Unsupported provider",
                request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
            ),
        )
    return response


def get_thread(
    custom_llm_provider: Literal["openai"],
    thread_id: str,
    client: Optional[OpenAI] = None,
    **kwargs,
) -> Thread:
    """Get the thread object, given a thread_id"""
    optional_params = GenericLiteLLMParams(**kwargs)

    ### TIMEOUT LOGIC ###
    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
    # set timeout for 10 minutes by default

    if (
        timeout is not None
        and isinstance(timeout, httpx.Timeout)
        and supports_httpx_timeout(custom_llm_provider) == False
    ):
        read_timeout = timeout.read or 600
        timeout = read_timeout  # default 10 min timeout
    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
        timeout = float(timeout)  # type: ignore
    elif timeout is None:
        timeout = 600.0

    response: Optional[Thread] = None
    if custom_llm_provider == "openai":
        api_base = (
            optional_params.api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            or litellm.api_base
            or os.getenv("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )
        organization = (
            optional_params.organization
            or litellm.organization
            or os.getenv("OPENAI_ORGANIZATION", None)
            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
        )
        # set API KEY
        api_key = (
            optional_params.api_key
            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
            or litellm.openai_key
            or os.getenv("OPENAI_API_KEY")
        )
        response = openai_assistants_api.get_thread(
            thread_id=thread_id,
            api_base=api_base,
            api_key=api_key,
            timeout=timeout,
            max_retries=optional_params.max_retries,
            organization=organization,
            client=client,
        )
    else:
        raise litellm.exceptions.BadRequestError(
            message="LiteLLM doesn't support {} for 'get_thread'. Only 'openai' is supported.".format(
                custom_llm_provider
            ),
            model="n/a",
            llm_provider=custom_llm_provider,
            response=httpx.Response(
                status_code=400,
                content="Unsupported provider",
                request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
            ),
        )
    return response


### MESSAGES ###


def add_message(
    custom_llm_provider: Literal["openai"],
    thread_id: str,
    role: Literal["user", "assistant"],
    content: str,
    attachments: Optional[List[Attachment]] = None,
    metadata: Optional[dict] = None,
    client: Optional[OpenAI] = None,
    **kwargs,
) -> OpenAIMessage:
    ### COMMON OBJECTS ###
    message_data = MessageData(
        role=role, content=content, attachments=attachments, metadata=metadata
    )
    optional_params = GenericLiteLLMParams(**kwargs)

    ### TIMEOUT LOGIC ###
    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
    # set timeout for 10 minutes by default

    if (
        timeout is not None
        and isinstance(timeout, httpx.Timeout)
        and supports_httpx_timeout(custom_llm_provider) == False
    ):
        read_timeout = timeout.read or 600
        timeout = read_timeout  # default 10 min timeout
    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
        timeout = float(timeout)  # type: ignore
    elif timeout is None:
        timeout = 600.0

    response: Optional[OpenAIMessage] = None
    if custom_llm_provider == "openai":
        api_base = (
            optional_params.api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            or litellm.api_base
            or os.getenv("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )
        organization = (
            optional_params.organization
            or litellm.organization
            or os.getenv("OPENAI_ORGANIZATION", None)
            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
        )
        # set API KEY
        api_key = (
            optional_params.api_key
            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
            or litellm.openai_key
            or os.getenv("OPENAI_API_KEY")
        )
        response = openai_assistants_api.add_message(
            thread_id=thread_id,
            message_data=message_data,
            api_base=api_base,
            api_key=api_key,
            timeout=timeout,
            max_retries=optional_params.max_retries,
            organization=organization,
            client=client,
        )
    else:
        raise litellm.exceptions.BadRequestError(
            message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
                custom_llm_provider
            ),
            model="n/a",
            llm_provider=custom_llm_provider,
            response=httpx.Response(
                status_code=400,
                content="Unsupported provider",
                request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
            ),
        )

    return response


def get_messages(
    custom_llm_provider: Literal["openai"],
    thread_id: str,
    client: Optional[OpenAI] = None,
    **kwargs,
) -> SyncCursorPage[OpenAIMessage]:
    optional_params = GenericLiteLLMParams(**kwargs)

    ### TIMEOUT LOGIC ###
    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
    # set timeout for 10 minutes by default

    if (
        timeout is not None
        and isinstance(timeout, httpx.Timeout)
        and supports_httpx_timeout(custom_llm_provider) == False
    ):
        read_timeout = timeout.read or 600
        timeout = read_timeout  # default 10 min timeout
    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
        timeout = float(timeout)  # type: ignore
    elif timeout is None:
        timeout = 600.0

    response: Optional[SyncCursorPage[OpenAIMessage]] = None
    if custom_llm_provider == "openai":
        api_base = (
            optional_params.api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            or litellm.api_base
            or os.getenv("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )
        organization = (
            optional_params.organization
            or litellm.organization
            or os.getenv("OPENAI_ORGANIZATION", None)
            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
        )
        # set API KEY
        api_key = (
            optional_params.api_key
            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
            or litellm.openai_key
            or os.getenv("OPENAI_API_KEY")
        )
        response = openai_assistants_api.get_messages(
            thread_id=thread_id,
            api_base=api_base,
            api_key=api_key,
            timeout=timeout,
            max_retries=optional_params.max_retries,
            organization=organization,
            client=client,
        )
    else:
        raise litellm.exceptions.BadRequestError(
            message="LiteLLM doesn't support {} for 'get_messages'. Only 'openai' is supported.".format(
                custom_llm_provider
            ),
            model="n/a",
            llm_provider=custom_llm_provider,
            response=httpx.Response(
                status_code=400,
                content="Unsupported provider",
                request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
            ),
        )

    return response


### RUNS ###


def run_thread(
    custom_llm_provider: Literal["openai"],
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str] = None,
    instructions: Optional[str] = None,
    metadata: Optional[dict] = None,
    model: Optional[str] = None,
    stream: Optional[bool] = None,
    tools: Optional[Iterable[AssistantToolParam]] = None,
    client: Optional[OpenAI] = None,
    **kwargs,
) -> Run:
    """Run a given thread + assistant."""
    optional_params = GenericLiteLLMParams(**kwargs)

    ### TIMEOUT LOGIC ###
    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
    # set timeout for 10 minutes by default

    if (
        timeout is not None
        and isinstance(timeout, httpx.Timeout)
        and supports_httpx_timeout(custom_llm_provider) == False
    ):
        read_timeout = timeout.read or 600
        timeout = read_timeout  # default 10 min timeout
    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
        timeout = float(timeout)  # type: ignore
    elif timeout is None:
        timeout = 600.0

    response: Optional[Run] = None
    if custom_llm_provider == "openai":
        api_base = (
            optional_params.api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            or litellm.api_base
            or os.getenv("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )
        organization = (
            optional_params.organization
            or litellm.organization
            or os.getenv("OPENAI_ORGANIZATION", None)
            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
        )
        # set API KEY
        api_key = (
            optional_params.api_key
            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
            or litellm.openai_key
            or os.getenv("OPENAI_API_KEY")
        )
        response = openai_assistants_api.run_thread(
            thread_id=thread_id,
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,
            model=model,
            stream=stream,
            tools=tools,
            api_base=api_base,
            api_key=api_key,
            timeout=timeout,
            max_retries=optional_params.max_retries,
            organization=organization,
            client=client,
        )
    else:
        raise litellm.exceptions.BadRequestError(
            message="LiteLLM doesn't support {} for 'run_thread'. Only 'openai' is supported.".format(
                custom_llm_provider
            ),
            model="n/a",
            llm_provider=custom_llm_provider,
            response=httpx.Response(
                status_code=400,
                content="Unsupported provider",
                request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
            ),
        )
    return response
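For orientation, here is a minimal usage sketch of the assistants surface added above. This is a sketch only: it assumes `OPENAI_API_KEY` is set, and `asst_XXXX` is a hypothetical placeholder for an assistant you have already created on the OpenAI side.

```python
# Hypothetical walkthrough of the new litellm assistants helpers (not part of the diff).
import litellm

ASSISTANT_ID = "asst_XXXX"  # placeholder - substitute a real assistant id

# 1. create a thread with an initial user message
thread = litellm.create_thread(
    custom_llm_provider="openai",
    messages=[{"role": "user", "content": "Hello, what is AI?"}],
)

# 2. append another message to the same thread
litellm.add_message(
    custom_llm_provider="openai",
    thread_id=thread.id,
    role="user",
    content="How does AI work? Explain it in simple terms.",
)

# 3. run the assistant against the thread, then read the messages back
run = litellm.run_thread(
    custom_llm_provider="openai",
    thread_id=thread.id,
    assistant_id=ASSISTANT_ID,
)
messages = litellm.get_messages(custom_llm_provider="openai", thread_id=thread.id)

# the run may still be queued/in_progress at this point; poll before relying on output
print(run.status, [m.id for m in messages.data])
```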
@@ -10,7 +10,7 @@
 import os, json, time
 import litellm
 from litellm.utils import ModelResponse
-import requests, threading
+import requests, threading  # type: ignore
 from typing import Optional, Union, Literal
@@ -106,7 +106,7 @@ class InMemoryCache(BaseCache):
             return_val.append(val)
         return return_val

-    async def async_increment(self, key, value: int, **kwargs) -> int:
+    async def async_increment(self, key, value: float, **kwargs) -> float:
         # get the value
         init_value = await self.async_get_cache(key=key) or 0
         value = init_value + value
@@ -177,11 +177,18 @@ class RedisCache(BaseCache):
         try:
             # asyncio.get_running_loop().create_task(self.ping())
             result = asyncio.get_running_loop().create_task(self.ping())
-        except Exception:
-            pass
+        except Exception as e:
+            verbose_logger.error(
+                "Error connecting to Async Redis client", extra={"error": str(e)}
+            )

         ### SYNC HEALTH PING ###
-        self.redis_client.ping()
+        try:
+            self.redis_client.ping()
+        except Exception as e:
+            verbose_logger.error(
+                "Error connecting to Sync Redis client", extra={"error": str(e)}
+            )

     def init_async_client(self):
         from ._redis import get_redis_async_client
@@ -416,12 +423,12 @@ class RedisCache(BaseCache):
         if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
             await self.flush_cache_buffer()  # logging done in here

-    async def async_increment(self, key, value: int, **kwargs) -> int:
+    async def async_increment(self, key, value: float, **kwargs) -> float:
         _redis_client = self.init_async_client()
         start_time = time.time()
         try:
             async with _redis_client as redis_client:
-                result = await redis_client.incr(name=key, amount=value)
+                result = await redis_client.incrbyfloat(name=key, amount=value)
                 ## LOGGING ##
                 end_time = time.time()
                 _duration = end_time - start_time
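The switch from `incr` to `incrbyfloat` matters because Redis `INCR`/`INCRBY` only accept integer amounts, while spend counters are incremented by fractional costs. A small sketch of the difference, assuming a local Redis on the default port:

```python
# Sketch only: why float increments need INCRBYFLOAT (assumes redis://localhost:6379).
import asyncio
import redis.asyncio as async_redis


async def main():
    client = async_redis.Redis(host="localhost", port=6379)
    await client.delete("spend")

    # incrbyfloat accepts fractional amounts and returns the new float value
    await client.incrbyfloat(name="spend", amount=0.000415)
    total = await client.incrbyfloat(name="spend", amount=0.002)
    print(total)  # 0.002415

    # incr / incrby reject a key holding a float value
    try:
        await client.incrby(name="spend", amount=1)
    except Exception as e:
        print(f"integer increment on a float key fails: {e}")

    await client.aclose()  # use close() on older redis-py versions


asyncio.run(main())
```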
@@ -1375,18 +1382,41 @@ class DualCache(BaseCache):
             print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
             traceback.print_exc()

+    async def async_batch_set_cache(
+        self, cache_list: list, local_only: bool = False, **kwargs
+    ):
+        """
+        Batch write values to the cache
+        """
+        print_verbose(
+            f"async batch set cache: cache keys: {cache_list}; local_only: {local_only}"
+        )
+        try:
+            if self.in_memory_cache is not None:
+                await self.in_memory_cache.async_set_cache_pipeline(
+                    cache_list=cache_list, **kwargs
+                )
+
+            if self.redis_cache is not None and local_only == False:
+                await self.redis_cache.async_set_cache_pipeline(
+                    cache_list=cache_list, ttl=kwargs.get("ttl", None)
+                )
+        except Exception as e:
+            print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
+            traceback.print_exc()
+
     async def async_increment_cache(
-        self, key, value: int, local_only: bool = False, **kwargs
-    ) -> int:
+        self, key, value: float, local_only: bool = False, **kwargs
+    ) -> float:
         """
         Key - the key in cache

-        Value - int - the value you want to increment by
+        Value - float - the value you want to increment by

-        Returns - int - the incremented value
+        Returns - float - the incremented value
         """
         try:
-            result: int = value
+            result: float = value
             if self.in_memory_cache is not None:
                 result = await self.in_memory_cache.async_increment(
                     key, value, **kwargs
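A usage sketch for the new `async_batch_set_cache` helper. It assumes, as the pipeline helpers it delegates to consume, that `cache_list` is a list of `(key, value)` pairs, and that the Redis connection details are available via the usual `REDIS_*` environment variables:

```python
# Sketch only: batch-writing several keys through DualCache in one call.
import asyncio
from litellm.caching import DualCache, InMemoryCache, RedisCache

dual_cache = DualCache(
    in_memory_cache=InMemoryCache(),
    redis_cache=RedisCache(),  # assumed to pick up REDIS_HOST / REDIS_PASSWORD from env
)

cache_list = [
    ("team:abc:spend", 0.42),
    ("team:xyz:spend", 1.07),
]

# writes to the in-memory cache and, unless local_only=True, to Redis as well
asyncio.run(dual_cache.async_batch_set_cache(cache_list=cache_list, ttl=120))
```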
@@ -1,7 +1,6 @@
 #### What this does ####
 # On success + failure, log events to aispend.io
 import dotenv, os
-import requests

 dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
@ -4,18 +4,30 @@ import datetime
|
||||||
class AthinaLogger:
|
class AthinaLogger:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
import os
|
import os
|
||||||
|
|
||||||
self.athina_api_key = os.getenv("ATHINA_API_KEY")
|
self.athina_api_key = os.getenv("ATHINA_API_KEY")
|
||||||
self.headers = {
|
self.headers = {
|
||||||
"athina-api-key": self.athina_api_key,
|
"athina-api-key": self.athina_api_key,
|
||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
self.athina_logging_url = "https://log.athina.ai/api/v1/log/inference"
|
self.athina_logging_url = "https://log.athina.ai/api/v1/log/inference"
|
||||||
self.additional_keys = ["environment", "prompt_slug", "customer_id", "customer_user_id", "session_id", "external_reference_id", "context", "expected_response", "user_query"]
|
self.additional_keys = [
|
||||||
|
"environment",
|
||||||
|
"prompt_slug",
|
||||||
|
"customer_id",
|
||||||
|
"customer_user_id",
|
||||||
|
"session_id",
|
||||||
|
"external_reference_id",
|
||||||
|
"context",
|
||||||
|
"expected_response",
|
||||||
|
"user_query",
|
||||||
|
]
|
||||||
|
|
||||||
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response_json = response_obj.model_dump() if response_obj else {}
|
response_json = response_obj.model_dump() if response_obj else {}
|
||||||
data = {
|
data = {
|
||||||
|
@ -23,32 +35,51 @@ class AthinaLogger:
|
||||||
"request": kwargs,
|
"request": kwargs,
|
||||||
"response": response_json,
|
"response": response_json,
|
||||||
"prompt_tokens": response_json.get("usage", {}).get("prompt_tokens"),
|
"prompt_tokens": response_json.get("usage", {}).get("prompt_tokens"),
|
||||||
"completion_tokens": response_json.get("usage", {}).get("completion_tokens"),
|
"completion_tokens": response_json.get("usage", {}).get(
|
||||||
|
"completion_tokens"
|
||||||
|
),
|
||||||
"total_tokens": response_json.get("usage", {}).get("total_tokens"),
|
"total_tokens": response_json.get("usage", {}).get("total_tokens"),
|
||||||
}
|
}
|
||||||
|
|
||||||
if type(end_time) == datetime.datetime and type(start_time) == datetime.datetime:
|
if (
|
||||||
data["response_time"] = int((end_time - start_time).total_seconds() * 1000)
|
type(end_time) == datetime.datetime
|
||||||
|
and type(start_time) == datetime.datetime
|
||||||
|
):
|
||||||
|
data["response_time"] = int(
|
||||||
|
(end_time - start_time).total_seconds() * 1000
|
||||||
|
)
|
||||||
|
|
||||||
if "messages" in kwargs:
|
if "messages" in kwargs:
|
||||||
data["prompt"] = kwargs.get("messages", None)
|
data["prompt"] = kwargs.get("messages", None)
|
||||||
|
|
||||||
# Directly add tools or functions if present
|
# Directly add tools or functions if present
|
||||||
optional_params = kwargs.get("optional_params", {})
|
optional_params = kwargs.get("optional_params", {})
|
||||||
data.update((k, v) for k, v in optional_params.items() if k in ["tools", "functions"])
|
data.update(
|
||||||
|
(k, v)
|
||||||
|
for k, v in optional_params.items()
|
||||||
|
if k in ["tools", "functions"]
|
||||||
|
)
|
||||||
|
|
||||||
# Add additional metadata keys
|
# Add additional metadata keys
|
||||||
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
||||||
if metadata:
|
if metadata:
|
||||||
for key in self.additional_keys:
|
for key in self.additional_keys:
|
||||||
if key in metadata:
|
if key in metadata:
|
||||||
data[key] = metadata[key]
|
data[key] = metadata[key]
|
||||||
|
|
||||||
response = requests.post(self.athina_logging_url, headers=self.headers, data=json.dumps(data, default=str))
|
response = requests.post(
|
||||||
|
self.athina_logging_url,
|
||||||
|
headers=self.headers,
|
||||||
|
data=json.dumps(data, default=str),
|
||||||
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
print_verbose(f"Athina Logger Error - {response.text}, {response.status_code}")
|
print_verbose(
|
||||||
|
f"Athina Logger Error - {response.text}, {response.status_code}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print_verbose(f"Athina Logger Succeeded - {response.text}")
|
print_verbose(f"Athina Logger Succeeded - {response.text}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_verbose(f"Athina Logger Error - {e}, Stack trace: {traceback.format_exc()}")
|
print_verbose(
|
||||||
pass
|
f"Athina Logger Error - {e}, Stack trace: {traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
pass
|
||||||
|
|
|
@@ -1,7 +1,7 @@
 #### What this does ####
 # On success + failure, log events to aispend.io
 import dotenv, os
-import requests
+import requests  # type: ignore

 dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
@ -3,7 +3,6 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success, logs events to Promptlayer
|
# On success, logs events to Promptlayer
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
|
||||||
|
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success, logs events to Promptlayer
|
# On success, logs events to Promptlayer
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
|
||||||
|
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
# On success + failure, log events to Supabase
|
# On success + failure, log events to Supabase
|
||||||
|
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
import requests # type: ignore
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||||
import traceback
|
import traceback
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
# On success + failure, log events to Supabase
|
# On success + failure, log events to Supabase
|
||||||
|
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
import requests # type: ignore
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||||
import traceback
|
import traceback
|
||||||
|
|
|
@ -1,15 +1,17 @@
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
|
||||||
class GreenscaleLogger:
|
class GreenscaleLogger:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
import os
|
import os
|
||||||
|
|
||||||
self.greenscale_api_key = os.getenv("GREENSCALE_API_KEY")
|
self.greenscale_api_key = os.getenv("GREENSCALE_API_KEY")
|
||||||
self.headers = {
|
self.headers = {
|
||||||
"api-key": self.greenscale_api_key,
|
"api-key": self.greenscale_api_key,
|
||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
self.greenscale_logging_url = os.getenv("GREENSCALE_ENDPOINT")
|
self.greenscale_logging_url = os.getenv("GREENSCALE_ENDPOINT")
|
||||||
|
|
||||||
|
@ -19,33 +21,48 @@ class GreenscaleLogger:
|
||||||
data = {
|
data = {
|
||||||
"modelId": kwargs.get("model"),
|
"modelId": kwargs.get("model"),
|
||||||
"inputTokenCount": response_json.get("usage", {}).get("prompt_tokens"),
|
"inputTokenCount": response_json.get("usage", {}).get("prompt_tokens"),
|
||||||
"outputTokenCount": response_json.get("usage", {}).get("completion_tokens"),
|
"outputTokenCount": response_json.get("usage", {}).get(
|
||||||
|
"completion_tokens"
|
||||||
|
),
|
||||||
}
|
}
|
||||||
data["timestamp"] = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
|
data["timestamp"] = datetime.now(timezone.utc).strftime(
|
||||||
|
"%Y-%m-%dT%H:%M:%SZ"
|
||||||
if type(end_time) == datetime and type(start_time) == datetime:
|
)
|
||||||
data["invocationLatency"] = int((end_time - start_time).total_seconds() * 1000)
|
|
||||||
|
|
||||||
|
if type(end_time) == datetime and type(start_time) == datetime:
|
||||||
|
data["invocationLatency"] = int(
|
||||||
|
(end_time - start_time).total_seconds() * 1000
|
||||||
|
)
|
||||||
|
|
||||||
# Add additional metadata keys to tags
|
# Add additional metadata keys to tags
|
||||||
tags = []
|
tags = []
|
||||||
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
||||||
for key, value in metadata.items():
|
for key, value in metadata.items():
|
||||||
if key.startswith("greenscale"):
|
if key.startswith("greenscale"):
|
||||||
if key == "greenscale_project":
|
if key == "greenscale_project":
|
||||||
data["project"] = value
|
data["project"] = value
|
||||||
elif key == "greenscale_application":
|
elif key == "greenscale_application":
|
||||||
data["application"] = value
|
data["application"] = value
|
||||||
else:
|
else:
|
||||||
tags.append({"key": key.replace("greenscale_", ""), "value": str(value)})
|
tags.append(
|
||||||
|
{"key": key.replace("greenscale_", ""), "value": str(value)}
|
||||||
|
)
|
||||||
|
|
||||||
data["tags"] = tags
|
data["tags"] = tags
|
||||||
|
|
||||||
response = requests.post(self.greenscale_logging_url, headers=self.headers, data=json.dumps(data, default=str))
|
response = requests.post(
|
||||||
|
self.greenscale_logging_url,
|
||||||
|
headers=self.headers,
|
||||||
|
data=json.dumps(data, default=str),
|
||||||
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
print_verbose(f"Greenscale Logger Error - {response.text}, {response.status_code}")
|
print_verbose(
|
||||||
|
f"Greenscale Logger Error - {response.text}, {response.status_code}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print_verbose(f"Greenscale Logger Succeeded - {response.text}")
|
print_verbose(f"Greenscale Logger Succeeded - {response.text}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_verbose(f"Greenscale Logger Error - {e}, Stack trace: {traceback.format_exc()}")
|
print_verbose(
|
||||||
pass
|
f"Greenscale Logger Error - {e}, Stack trace: {traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
pass
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success, logs events to Helicone
|
# On success, logs events to Helicone
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||||
|
|
|
@@ -262,6 +262,23 @@ class LangFuseLogger:

         try:
             tags = []
+            try:
+                metadata = copy.deepcopy(
+                    metadata
+                )  # Avoid modifying the original metadata
+            except:
+                new_metadata = {}
+                for key, value in metadata.items():
+                    if (
+                        isinstance(value, list)
+                        or isinstance(value, dict)
+                        or isinstance(value, str)
+                        or isinstance(value, int)
+                        or isinstance(value, float)
+                    ):
+                        new_metadata[key] = copy.deepcopy(value)
+                metadata = new_metadata
+
             supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3")
             supports_prompt = Version(langfuse.version.__version__) >= Version("2.7.3")
             supports_costs = Version(langfuse.version.__version__) >= Version("2.7.3")
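The fallback path above exists because caller-supplied metadata can contain objects that `copy.deepcopy` cannot handle (locks, clients, proxies); in that case only plainly copyable values are kept. A standalone sketch of the same pattern, not litellm code:

```python
# Sketch of the defensive-copy pattern used above.
import copy
import threading

metadata = {"tags": ["prod"], "attempts": 3, "lock": threading.Lock()}

try:
    metadata = copy.deepcopy(metadata)  # avoid mutating the caller's dict
except Exception:
    # fall back to copying only simple, copyable values
    new_metadata = {}
    for key, value in metadata.items():
        if isinstance(value, (list, dict, str, int, float)):
            new_metadata[key] = copy.deepcopy(value)
    metadata = new_metadata

print(sorted(metadata.keys()))  # the un-copyable 'lock' entry is dropped
```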
@ -272,36 +289,9 @@ class LangFuseLogger:
|
||||||
print_verbose(f"Langfuse Layer Logging - logging to langfuse v2 ")
|
print_verbose(f"Langfuse Layer Logging - logging to langfuse v2 ")
|
||||||
|
|
||||||
if supports_tags:
|
if supports_tags:
|
||||||
metadata_tags = metadata.get("tags", [])
|
metadata_tags = metadata.pop("tags", [])
|
||||||
tags = metadata_tags
|
tags = metadata_tags
|
||||||
|
|
||||||
trace_name = metadata.get("trace_name", None)
|
|
||||||
trace_id = metadata.get("trace_id", None)
|
|
||||||
existing_trace_id = metadata.get("existing_trace_id", None)
|
|
||||||
if trace_name is None and existing_trace_id is None:
|
|
||||||
# just log `litellm-{call_type}` as the trace name
|
|
||||||
## DO NOT SET TRACE_NAME if trace-id set. this can lead to overwriting of past traces.
|
|
||||||
trace_name = f"litellm-{kwargs.get('call_type', 'completion')}"
|
|
||||||
|
|
||||||
if existing_trace_id is not None:
|
|
||||||
trace_params = {"id": existing_trace_id}
|
|
||||||
else: # don't overwrite an existing trace
|
|
||||||
trace_params = {
|
|
||||||
"name": trace_name,
|
|
||||||
"input": input,
|
|
||||||
"user_id": metadata.get("trace_user_id", user_id),
|
|
||||||
"id": trace_id,
|
|
||||||
"session_id": metadata.get("session_id", None),
|
|
||||||
}
|
|
||||||
|
|
||||||
if level == "ERROR":
|
|
||||||
trace_params["status_message"] = output
|
|
||||||
else:
|
|
||||||
trace_params["output"] = output
|
|
||||||
|
|
||||||
cost = kwargs.get("response_cost", None)
|
|
||||||
print_verbose(f"trace: {cost}")
|
|
||||||
|
|
||||||
# Clean Metadata before logging - never log raw metadata
|
# Clean Metadata before logging - never log raw metadata
|
||||||
# the raw metadata can contain circular references which leads to infinite recursion
|
# the raw metadata can contain circular references which leads to infinite recursion
|
||||||
# we clean out all extra litellm metadata params before logging
|
# we clean out all extra litellm metadata params before logging
|
||||||
|
@ -328,6 +318,67 @@ class LangFuseLogger:
|
||||||
else:
|
else:
|
||||||
clean_metadata[key] = value
|
clean_metadata[key] = value
|
||||||
|
|
||||||
|
session_id = clean_metadata.pop("session_id", None)
|
||||||
|
trace_name = clean_metadata.pop("trace_name", None)
|
||||||
|
trace_id = clean_metadata.pop("trace_id", None)
|
||||||
|
existing_trace_id = clean_metadata.pop("existing_trace_id", None)
|
||||||
|
update_trace_keys = clean_metadata.pop("update_trace_keys", [])
|
||||||
|
|
||||||
|
if trace_name is None and existing_trace_id is None:
|
||||||
|
# just log `litellm-{call_type}` as the trace name
|
||||||
|
## DO NOT SET TRACE_NAME if trace-id set. this can lead to overwriting of past traces.
|
||||||
|
trace_name = f"litellm-{kwargs.get('call_type', 'completion')}"
|
||||||
|
|
||||||
|
if existing_trace_id is not None:
|
||||||
|
trace_params = {"id": existing_trace_id}
|
||||||
|
|
||||||
|
# Update the following keys for this trace
|
||||||
|
for metadata_param_key in update_trace_keys:
|
||||||
|
trace_param_key = metadata_param_key.replace("trace_", "")
|
||||||
|
if trace_param_key not in trace_params:
|
||||||
|
updated_trace_value = clean_metadata.pop(
|
||||||
|
metadata_param_key, None
|
||||||
|
)
|
||||||
|
if updated_trace_value is not None:
|
||||||
|
trace_params[trace_param_key] = updated_trace_value
|
||||||
|
|
||||||
|
# Pop the trace specific keys that would have been popped if there were a new trace
|
||||||
|
for key in list(
|
||||||
|
filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
|
||||||
|
):
|
||||||
|
clean_metadata.pop(key, None)
|
||||||
|
|
||||||
|
# Special keys that are found in the function arguments and not the metadata
|
||||||
|
if "input" in update_trace_keys:
|
||||||
|
trace_params["input"] = input
|
||||||
|
if "output" in update_trace_keys:
|
||||||
|
trace_params["output"] = output
|
||||||
|
else: # don't overwrite an existing trace
|
||||||
|
trace_params = {
|
||||||
|
"id": trace_id,
|
||||||
|
"name": trace_name,
|
||||||
|
"session_id": session_id,
|
||||||
|
"input": input,
|
||||||
|
"version": clean_metadata.pop(
|
||||||
|
"trace_version", clean_metadata.get("version", None)
|
||||||
|
), # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence
|
||||||
|
"user_id": user_id,
|
||||||
|
}
|
||||||
|
for key in list(
|
||||||
|
filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
|
||||||
|
):
|
||||||
|
trace_params[key.replace("trace_", "")] = clean_metadata.pop(
|
||||||
|
key, None
|
||||||
|
)
|
||||||
|
|
||||||
|
if level == "ERROR":
|
||||||
|
trace_params["status_message"] = output
|
||||||
|
else:
|
||||||
|
trace_params["output"] = output
|
||||||
|
|
||||||
|
cost = kwargs.get("response_cost", None)
|
||||||
|
print_verbose(f"trace: {cost}")
|
||||||
|
|
||||||
if (
|
if (
|
||||||
litellm._langfuse_default_tags is not None
|
litellm._langfuse_default_tags is not None
|
||||||
and isinstance(litellm._langfuse_default_tags, list)
|
and isinstance(litellm._langfuse_default_tags, list)
|
||||||
|
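In practice these keys are driven from request metadata: `existing_trace_id` attaches the generation to an earlier Langfuse trace, and `update_trace_keys` lists which `trace_*` fields (plus `input`/`output`) should be rewritten on that trace. A hedged sketch of what a caller might send; the parameter names mirror the code above, while the model and id values are placeholders:

```python
# Sketch only: metadata a caller could attach so the Langfuse logger updates an existing trace.
import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",  # placeholder model
    messages=[{"role": "user", "content": "summarize our last exchange"}],
    metadata={
        "existing_trace_id": "trace-1234",                  # reuse this Langfuse trace
        "trace_metadata": {"step": "follow-up"},            # becomes the trace's 'metadata'
        "update_trace_keys": ["trace_metadata", "output"],  # fields to overwrite on the trace
        "generation_name": "follow-up-summary",
    },
)
```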
@ -387,7 +438,7 @@ class LangFuseLogger:
|
||||||
"completion_tokens": response_obj["usage"]["completion_tokens"],
|
"completion_tokens": response_obj["usage"]["completion_tokens"],
|
||||||
"total_cost": cost if supports_costs else None,
|
"total_cost": cost if supports_costs else None,
|
||||||
}
|
}
|
||||||
generation_name = metadata.get("generation_name", None)
|
generation_name = clean_metadata.pop("generation_name", None)
|
||||||
if generation_name is None:
|
if generation_name is None:
|
||||||
# just log `litellm-{call_type}` as the generation name
|
# just log `litellm-{call_type}` as the generation name
|
||||||
generation_name = f"litellm-{kwargs.get('call_type', 'completion')}"
|
generation_name = f"litellm-{kwargs.get('call_type', 'completion')}"
|
||||||
|
@ -402,7 +453,7 @@ class LangFuseLogger:
|
||||||
|
|
||||||
generation_params = {
|
generation_params = {
|
||||||
"name": generation_name,
|
"name": generation_name,
|
||||||
"id": metadata.get("generation_id", generation_id),
|
"id": clean_metadata.pop("generation_id", generation_id),
|
||||||
"start_time": start_time,
|
"start_time": start_time,
|
||||||
"end_time": end_time,
|
"end_time": end_time,
|
||||||
"model": kwargs["model"],
|
"model": kwargs["model"],
|
||||||
|
@ -412,10 +463,11 @@ class LangFuseLogger:
|
||||||
"usage": usage,
|
"usage": usage,
|
||||||
"metadata": clean_metadata,
|
"metadata": clean_metadata,
|
||||||
"level": level,
|
"level": level,
|
||||||
|
"version": clean_metadata.pop("version", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
if supports_prompt:
|
if supports_prompt:
|
||||||
generation_params["prompt"] = metadata.get("prompt", None)
|
generation_params["prompt"] = clean_metadata.pop("prompt", None)
|
||||||
|
|
||||||
if output is not None and isinstance(output, str) and level == "ERROR":
|
if output is not None and isinstance(output, str) and level == "ERROR":
|
||||||
generation_params["status_message"] = output
|
generation_params["status_message"] = output
|
||||||
|
@ -426,7 +478,7 @@ class LangFuseLogger:
|
||||||
)
|
)
|
||||||
|
|
||||||
generation_client = trace.generation(**generation_params)
|
generation_client = trace.generation(**generation_params)
|
||||||
|
|
||||||
return generation_client.trace_id, generation_id
|
return generation_client.trace_id, generation_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")
|
verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")
|
||||||
|
|
|
@ -1,15 +1,14 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success, logs events to Langsmith
|
# On success, logs events to Langsmith
|
||||||
import dotenv, os
|
import dotenv, os # type: ignore
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import requests
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||||
import traceback
|
import traceback
|
||||||
import asyncio
|
import asyncio
|
||||||
import types
|
import types
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def is_serializable(value):
|
def is_serializable(value):
|
||||||
|
@ -79,8 +78,6 @@ class LangsmithLogger:
|
||||||
except:
|
except:
|
||||||
response_obj = response_obj.dict() # type: ignore
|
response_obj = response_obj.dict() # type: ignore
|
||||||
|
|
||||||
print(f"response_obj: {response_obj}")
|
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"name": run_name,
|
"name": run_name,
|
||||||
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
|
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
|
||||||
|
@ -90,7 +87,6 @@ class LangsmithLogger:
|
||||||
"start_time": start_time,
|
"start_time": start_time,
|
||||||
"end_time": end_time,
|
"end_time": end_time,
|
||||||
}
|
}
|
||||||
print(f"data: {data}")
|
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
"https://api.smith.langchain.com/runs",
|
"https://api.smith.langchain.com/runs",
|
||||||
|
|
|
@@ -4,7 +4,6 @@ from datetime import datetime, timezone
 import traceback
 import dotenv
 import importlib
-import sys

 import packaging
|
|
||||||
|
@@ -18,13 +17,33 @@ def parse_usage(usage):
         "prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
     }

+
+def parse_tool_calls(tool_calls):
+    if tool_calls is None:
+        return None
+
+    def clean_tool_call(tool_call):
+
+        serialized = {
+            "type": tool_call.type,
+            "id": tool_call.id,
+            "function": {
+                "name": tool_call.function.name,
+                "arguments": tool_call.function.arguments,
+            }
+        }
+
+        return serialized
+
+    return [clean_tool_call(tool_call) for tool_call in tool_calls]
+
+
 def parse_messages(input):

     if input is None:
         return None

     def clean_message(message):
-        # if is strin, return as is
+        # if is string, return as is
         if isinstance(message, str):
             return message

@@ -38,9 +57,7 @@ def parse_messages(input):

         # Only add tool_calls and function_call to res if they are set
         if message.get("tool_calls"):
-            serialized["tool_calls"] = message.get("tool_calls")
+            serialized["tool_calls"] = parse_tool_calls(message.get("tool_calls"))
-        if message.get("function_call"):
-            serialized["function_call"] = message.get("function_call")

         return serialized
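`parse_tool_calls` flattens tool calls into plain dicts before they are shipped to Lunary, most likely because the OpenAI SDK returns them as objects that are not directly JSON-serializable. A sketch with a hand-built stand-in object (the types here are illustrative, not the real SDK classes):

```python
# Sketch only: flattening an OpenAI-style tool call into the dict shape used above.
from types import SimpleNamespace

tool_call = SimpleNamespace(
    type="function",
    id="call_abc123",
    function=SimpleNamespace(
        name="get_weather",
        arguments='{"location": "Paris"}',
    ),
)


def clean_tool_call(tool_call):
    return {
        "type": tool_call.type,
        "id": tool_call.id,
        "function": {
            "name": tool_call.function.name,
            "arguments": tool_call.function.arguments,
        },
    }


print(clean_tool_call(tool_call))
# {'type': 'function', 'id': 'call_abc123', 'function': {'name': 'get_weather', 'arguments': '{"location": "Paris"}'}}
```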
@ -93,8 +110,13 @@ class LunaryLogger:
|
||||||
print_verbose(f"Lunary Logging - Logging request for model {model}")
|
print_verbose(f"Lunary Logging - Logging request for model {model}")
|
||||||
|
|
||||||
litellm_params = kwargs.get("litellm_params", {})
|
litellm_params = kwargs.get("litellm_params", {})
|
||||||
|
optional_params = kwargs.get("optional_params", {})
|
||||||
metadata = litellm_params.get("metadata", {}) or {}
|
metadata = litellm_params.get("metadata", {}) or {}
|
||||||
|
|
||||||
|
if optional_params:
|
||||||
|
# merge into extra
|
||||||
|
extra = {**extra, **optional_params}
|
||||||
|
|
||||||
tags = litellm_params.pop("tags", None) or []
|
tags = litellm_params.pop("tags", None) or []
|
||||||
|
|
||||||
if extra:
|
if extra:
|
||||||
|
@ -104,7 +126,7 @@ class LunaryLogger:
|
||||||
|
|
||||||
# keep only serializable types
|
# keep only serializable types
|
||||||
for param, value in extra.items():
|
for param, value in extra.items():
|
||||||
if not isinstance(value, (str, int, bool, float)):
|
if not isinstance(value, (str, int, bool, float)) and param != "tools":
|
||||||
try:
|
try:
|
||||||
extra[param] = str(value)
|
extra[param] = str(value)
|
||||||
except:
|
except:
|
||||||
|
@ -140,7 +162,7 @@ class LunaryLogger:
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
runtime="litellm",
|
runtime="litellm",
|
||||||
tags=tags,
|
tags=tags,
|
||||||
extra=extra,
|
params=extra,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.lunary_client.track_event(
|
self.lunary_client.track_event(
|
||||||
|
|
|
@ -2,7 +2,6 @@
|
||||||
## On Success events log cost to OpenMeter - https://github.com/BerriAI/litellm/issues/1268
|
## On Success events log cost to OpenMeter - https://github.com/BerriAI/litellm/issues/1268
|
||||||
|
|
||||||
import dotenv, os, json
|
import dotenv, os, json
|
||||||
import requests
|
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||||
|
@ -38,7 +37,7 @@ class OpenMeterLogger(CustomLogger):
|
||||||
in the environment
|
in the environment
|
||||||
"""
|
"""
|
||||||
missing_keys = []
|
missing_keys = []
|
||||||
if litellm.get_secret("OPENMETER_API_KEY", None) is None:
|
if os.getenv("OPENMETER_API_KEY", None) is None:
|
||||||
missing_keys.append("OPENMETER_API_KEY")
|
missing_keys.append("OPENMETER_API_KEY")
|
||||||
|
|
||||||
if len(missing_keys) > 0:
|
if len(missing_keys) > 0:
|
||||||
|
@ -60,47 +59,56 @@ class OpenMeterLogger(CustomLogger):
|
||||||
"total_tokens": response_obj["usage"].get("total_tokens"),
|
"total_tokens": response_obj["usage"].get("total_tokens"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
subject = (kwargs.get("user", None),) # end-user passed in via 'user' param
|
||||||
|
if not subject:
|
||||||
|
raise Exception("OpenMeter: user is required")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"specversion": "1.0",
|
"specversion": "1.0",
|
||||||
"type": os.getenv("OPENMETER_EVENT_TYPE", "litellm_tokens"),
|
"type": os.getenv("OPENMETER_EVENT_TYPE", "litellm_tokens"),
|
||||||
"id": call_id,
|
"id": call_id,
|
||||||
"time": dt,
|
"time": dt,
|
||||||
"subject": kwargs.get("user", ""), # end-user passed in via 'user' param
|
"subject": subject,
|
||||||
"source": "litellm-proxy",
|
"source": "litellm-proxy",
|
||||||
"data": {"model": model, "cost": cost, **usage},
|
"data": {"model": model, "cost": cost, **usage},
|
||||||
}
|
}
|
||||||
|
|
||||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
_url = litellm.get_secret(
|
_url = os.getenv("OPENMETER_API_ENDPOINT", "https://openmeter.cloud")
|
||||||
"OPENMETER_API_ENDPOINT", default_value="https://openmeter.cloud"
|
|
||||||
)
|
|
||||||
if _url.endswith("/"):
|
if _url.endswith("/"):
|
||||||
_url += "api/v1/events"
|
_url += "api/v1/events"
|
||||||
else:
|
else:
|
||||||
_url += "/api/v1/events"
|
_url += "/api/v1/events"
|
||||||
|
|
||||||
api_key = litellm.get_secret("OPENMETER_API_KEY")
|
api_key = os.getenv("OPENMETER_API_KEY")
|
||||||
|
|
||||||
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
|
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
|
||||||
self.sync_http_handler.post(
|
_headers = {
|
||||||
url=_url,
|
"Content-Type": "application/cloudevents+json",
|
||||||
data=_data,
|
"Authorization": "Bearer {}".format(api_key),
|
||||||
headers={
|
}
|
||||||
"Content-Type": "application/cloudevents+json",
|
|
||||||
"Authorization": "Bearer {}".format(api_key),
|
try:
|
||||||
},
|
response = self.sync_http_handler.post(
|
||||||
)
|
url=_url,
|
||||||
|
data=json.dumps(_data),
|
||||||
|
headers=_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
except Exception as e:
|
||||||
|
if hasattr(response, "text"):
|
||||||
|
litellm.print_verbose(f"\nError Message: {response.text}")
|
||||||
|
raise e
|
||||||
|
|
||||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
_url = litellm.get_secret(
|
_url = os.getenv("OPENMETER_API_ENDPOINT", "https://openmeter.cloud")
|
||||||
"OPENMETER_API_ENDPOINT", default_value="https://openmeter.cloud"
|
|
||||||
)
|
|
||||||
if _url.endswith("/"):
|
if _url.endswith("/"):
|
||||||
_url += "api/v1/events"
|
_url += "api/v1/events"
|
||||||
else:
|
else:
|
||||||
_url += "/api/v1/events"
|
_url += "/api/v1/events"
|
||||||
|
|
||||||
api_key = litellm.get_secret("OPENMETER_API_KEY")
|
api_key = os.getenv("OPENMETER_API_KEY")
|
||||||
|
|
||||||
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
|
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
|
||||||
_headers = {
|
_headers = {
|
||||||
|
@ -117,7 +125,6 @@ class OpenMeterLogger(CustomLogger):
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\nAn Exception Occurred - {str(e)}")
|
|
||||||
if hasattr(response, "text"):
|
if hasattr(response, "text"):
|
||||||
print(f"\nError Message: {response.text}")
|
litellm.print_verbose(f"\nError Message: {response.text}")
|
||||||
raise e
|
raise e
|
||||||
|
|
|
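The payload assembled by `_common_logic` above follows the CloudEvents envelope that OpenMeter's `/api/v1/events` endpoint ingests. A sketch of posting one such event directly, assuming `OPENMETER_API_KEY` is set in the environment; the field values are illustrative:

```python
# Sketch only: posting a CloudEvents-style usage event to OpenMeter with plain httpx.
import json
import os
import uuid
from datetime import datetime, timezone

import httpx

event = {
    "specversion": "1.0",
    "type": os.getenv("OPENMETER_EVENT_TYPE", "litellm_tokens"),
    "id": str(uuid.uuid4()),
    "time": datetime.now(timezone.utc).isoformat(),
    "subject": "user-1234",  # the end user passed via the 'user' param
    "source": "litellm-proxy",
    "data": {"model": "gpt-3.5-turbo", "cost": 0.00042, "total_tokens": 210},
}

base_url = os.getenv("OPENMETER_API_ENDPOINT", "https://openmeter.cloud").rstrip("/")
response = httpx.post(
    f"{base_url}/api/v1/events",
    headers={
        "Content-Type": "application/cloudevents+json",
        "Authorization": f"Bearer {os.getenv('OPENMETER_API_KEY')}",
    },
    content=json.dumps(event),
)
response.raise_for_status()
```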
@ -3,7 +3,7 @@
|
||||||
# On success, log events to Prometheus
|
# On success, log events to Prometheus
|
||||||
|
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
import requests # type: ignore
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||||
import traceback
|
import traceback
|
||||||
|
@ -19,7 +19,6 @@ class PrometheusLogger:
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
print(f"in init prometheus metrics")
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
self.litellm_llm_api_failed_requests_metric = Counter(
|
self.litellm_llm_api_failed_requests_metric = Counter(
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
|
|
||||||
|
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
import requests # type: ignore
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||||
import traceback
|
import traceback
|
||||||
|
@ -183,7 +183,6 @@ class PrometheusServicesLogger:
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
|
async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
|
||||||
print(f"received error payload: {payload.error}")
|
|
||||||
if self.mock_testing:
|
if self.mock_testing:
|
||||||
self.mock_testing_failure_calls += 1
|
self.mock_testing_failure_calls += 1
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success, logs events to Promptlayer
|
# On success, logs events to Promptlayer
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
import requests # type: ignore
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
class PromptLayerLogger:
|
class PromptLayerLogger:
|
||||||
# Class variables or attributes
|
# Class variables or attributes
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -32,7 +33,11 @@ class PromptLayerLogger:
|
||||||
tags = kwargs["litellm_params"]["metadata"]["pl_tags"]
|
tags = kwargs["litellm_params"]["metadata"]["pl_tags"]
|
||||||
|
|
||||||
# Remove "pl_tags" from metadata
|
# Remove "pl_tags" from metadata
|
||||||
metadata = {k:v for k, v in kwargs["litellm_params"]["metadata"].items() if k != "pl_tags"}
|
metadata = {
|
||||||
|
k: v
|
||||||
|
for k, v in kwargs["litellm_params"]["metadata"].items()
|
||||||
|
if k != "pl_tags"
|
||||||
|
}
|
||||||
|
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}"
|
f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}"
|
||||||
|
|
|
@ -2,7 +2,6 @@
|
||||||
# On success + failure, log events to Supabase
|
# On success + failure, log events to Supabase
|
||||||
|
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
|
||||||
|
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||||
import traceback
|
import traceback
|
||||||
|
|
|
@@ -1,25 +1,82 @@
 #### What this does ####
 # Class for sending Slack Alerts #
 import dotenv, os
+from litellm.proxy._types import UserAPIKeyAuth
 
 dotenv.load_dotenv()  # Loading env variables using dotenv
-import copy
-import traceback
 from litellm._logging import verbose_logger, verbose_proxy_logger
-import litellm
+import litellm, threading
 from typing import List, Literal, Any, Union, Optional, Dict
 from litellm.caching import DualCache
 import asyncio
 import aiohttp
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 import datetime
+from pydantic import BaseModel
+from enum import Enum
+from datetime import datetime as dt, timedelta
+from litellm.integrations.custom_logger import CustomLogger
+import random
 
 
-class SlackAlerting:
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+
+class SlackAlertingArgs(LiteLLMBase):
+    daily_report_frequency: int = 12 * 60 * 60  # 12 hours
+    report_check_interval: int = 5 * 60  # 5 minutes
+
+
+class DeploymentMetrics(LiteLLMBase):
+    """
+    Metrics per deployment, stored in cache
+
+    Used for daily reporting
+    """
+
+    id: str
+    """id of deployment in router model list"""
+
+    failed_request: bool
+    """did it fail the request?"""
+
+    latency_per_output_token: Optional[float]
+    """latency/output token of deployment"""
+
+    updated_at: dt
+    """Current time of deployment being updated"""
+
+
+class SlackAlertingCacheKeys(Enum):
+    """
+    Enum for deployment daily metrics keys - {deployment_id}:{enum}
+    """
+
+    failed_requests_key = "failed_requests_daily_metrics"
+    latency_key = "latency_daily_metrics"
+    report_sent_key = "daily_metrics_report_sent"
+
+
+class SlackAlerting(CustomLogger):
+    """
+    Class for sending Slack Alerts
+    """
+
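Reviewer note: taken together with the constructor changes in the hunks just below, the new classes above give the proxy a typed configuration surface for daily reports. A hedged sketch of how this is expected to be wired up (webhook URL, thresholds, and overrides are placeholders):

```python
# Illustrative: constructing SlackAlerting with the new daily-report arguments.
from litellm.caching import DualCache
from litellm.integrations.slack_alerting import SlackAlerting

internal_usage_cache = DualCache()  # shared with the rest of the proxy

slack_alerting = SlackAlerting(
    internal_usage_cache=internal_usage_cache,
    alerting_threshold=300,  # seconds before "slow / hanging" alerts fire
    alerting=["slack"],
    alert_types=["llm_exceptions", "llm_too_slow", "daily_reports"],
    alerting_args={"daily_report_frequency": 6 * 60 * 60},  # override: report every 6h
    default_webhook_url="https://hooks.slack.com/services/T000/B000/XXXX",  # placeholder
)
```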
     # Class variables or attributes
     def __init__(
         self,
-        alerting_threshold: float = 300,
+        internal_usage_cache: Optional[DualCache] = None,
+        alerting_threshold: float = 300,  # threshold for slow / hanging llm responses (in seconds)
         alerting: Optional[List] = [],
         alert_types: Optional[
             List[
@@ -29,6 +86,7 @@ class SlackAlerting:
                     "llm_requests_hanging",
                     "budget_alerts",
                     "db_exceptions",
+                    "daily_reports",
                 ]
             ]
         ] = [
@@ -37,31 +95,23 @@ class SlackAlerting:
             "llm_requests_hanging",
             "budget_alerts",
             "db_exceptions",
+            "daily_reports",
         ],
         alert_to_webhook_url: Optional[
             Dict
         ] = None,  # if user wants to separate alerts to diff channels
+        alerting_args={},
+        default_webhook_url: Optional[str] = None,
     ):
         self.alerting_threshold = alerting_threshold
         self.alerting = alerting
         self.alert_types = alert_types
-        self.internal_usage_cache = DualCache()
+        self.internal_usage_cache = internal_usage_cache or DualCache()
         self.async_http_handler = AsyncHTTPHandler()
         self.alert_to_webhook_url = alert_to_webhook_url
-        self.langfuse_logger = None
-        try:
-            from litellm.integrations.langfuse import LangFuseLogger
-
-            self.langfuse_logger = LangFuseLogger(
-                os.getenv("LANGFUSE_PUBLIC_KEY"),
-                os.getenv("LANGFUSE_SECRET_KEY"),
-                flush_interval=1,
-            )
-        except:
-            pass
-
-        pass
+        self.is_running = False
+        self.alerting_args = SlackAlertingArgs(**alerting_args)
+        self.default_webhook_url = default_webhook_url
 
     def update_values(
         self,
@@ -69,6 +119,7 @@ class SlackAlerting:
         alerting_threshold: Optional[float] = None,
         alert_types: Optional[List] = None,
         alert_to_webhook_url: Optional[Dict] = None,
+        alerting_args: Optional[Dict] = None,
     ):
         if alerting is not None:
             self.alerting = alerting
@@ -76,7 +127,8 @@ class SlackAlerting:
             self.alerting_threshold = alerting_threshold
         if alert_types is not None:
             self.alert_types = alert_types
+        if alerting_args is not None:
+            self.alerting_args = SlackAlertingArgs(**alerting_args)
         if alert_to_webhook_url is not None:
             # update the dict
             if self.alert_to_webhook_url is None:
@@ -103,72 +155,23 @@ class SlackAlerting:
 
     def _add_langfuse_trace_id_to_alert(
         self,
-        request_info: str,
         request_data: Optional[dict] = None,
-        kwargs: Optional[dict] = None,
-        type: Literal["hanging_request", "slow_response"] = "hanging_request",
-        start_time: Optional[datetime.datetime] = None,
-        end_time: Optional[datetime.datetime] = None,
-    ):
-        import uuid
+    ) -> Optional[str]:
+        """
+        Returns langfuse trace url
+        """
+        # do nothing for now
+        if (
+            request_data is not None
+            and request_data.get("metadata", {}).get("trace_id", None) is not None
+        ):
+            trace_id = request_data["metadata"]["trace_id"]
+            if litellm.utils.langFuseLogger is not None:
+                base_url = litellm.utils.langFuseLogger.Langfuse.base_url
+                return f"{base_url}/trace/{trace_id}"
+        return None
 
-        # For now: do nothing as we're debugging why this is not working as expected
-        if request_data is not None:
-            trace_id = request_data.get("metadata", {}).get(
-                "trace_id", None
-            )  # get langfuse trace id
-            if trace_id is None:
-                trace_id = "litellm-alert-trace-" + str(uuid.uuid4())
-                request_data["metadata"]["trace_id"] = trace_id
-        elif kwargs is not None:
-            _litellm_params = kwargs.get("litellm_params", {})
-            trace_id = _litellm_params.get("metadata", {}).get(
-                "trace_id", None
-            )  # get langfuse trace id
-            if trace_id is None:
-                trace_id = "litellm-alert-trace-" + str(uuid.uuid4())
-                _litellm_params["metadata"]["trace_id"] = trace_id
-
-        # Log hanging request as an error on langfuse
-        if type == "hanging_request":
-            if self.langfuse_logger is not None:
-                _logging_kwargs = copy.deepcopy(request_data)
-                if _logging_kwargs is None:
-                    _logging_kwargs = {}
-                _logging_kwargs["litellm_params"] = {}
-                request_data = request_data or {}
-                _logging_kwargs["litellm_params"]["metadata"] = request_data.get(
-                    "metadata", {}
-                )
-                # log to langfuse in a separate thread
-                import threading
-
-                threading.Thread(
-                    target=self.langfuse_logger.log_event,
-                    args=(
-                        _logging_kwargs,
-                        None,
-                        start_time,
-                        end_time,
-                        None,
-                        print,
-                        "ERROR",
-                        "Requests is hanging",
-                    ),
-                ).start()
-
-        _langfuse_host = os.environ.get("LANGFUSE_HOST", "https://cloud.langfuse.com")
-        _langfuse_project_id = os.environ.get("LANGFUSE_PROJECT_ID")
-
-        # langfuse urls look like: https://us.cloud.langfuse.com/project/************/traces/litellm-alert-trace-ididi9dk-09292-************
-
-        _langfuse_url = (
-            f"{_langfuse_host}/project/{_langfuse_project_id}/traces/{trace_id}"
-        )
-        request_info += f"\n🪢 Langfuse Trace: {_langfuse_url}"
-        return request_info
-
-    def _response_taking_too_long_callback(
+    def _response_taking_too_long_callback_helper(
         self,
         kwargs,  # kwargs to completion
         start_time,
@@ -233,7 +236,7 @@ class SlackAlerting:
             return
 
         time_difference_float, model, api_base, messages = (
-            self._response_taking_too_long_callback(
+            self._response_taking_too_long_callback_helper(
                 kwargs=kwargs,
                 start_time=start_time,
                 end_time=end_time,
@@ -242,10 +245,6 @@ class SlackAlerting:
         request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
        slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
         if time_difference_float > self.alerting_threshold:
-            if "langfuse" in litellm.success_callback:
-                request_info = self._add_langfuse_trace_id_to_alert(
-                    request_info=request_info, kwargs=kwargs, type="slow_response"
-                )
             # add deployment latencies to alert
             if (
                 kwargs is not None
@@ -253,6 +252,9 @@ class SlackAlerting:
                 and "metadata" in kwargs["litellm_params"]
             ):
                 _metadata = kwargs["litellm_params"]["metadata"]
+                request_info = litellm.utils._add_key_name_and_team_to_alert(
+                    request_info=request_info, metadata=_metadata
+                )
+
                 _deployment_latency_map = self._get_deployment_latencies_to_alert(
                     metadata=_metadata
@@ -267,8 +269,178 @@ class SlackAlerting:
                     alert_type="llm_too_slow",
                 )
 
-    async def log_failure_event(self, original_exception: Exception):
-        pass
+    async def async_update_daily_reports(
+        self, deployment_metrics: DeploymentMetrics
+    ) -> int:
+        """
+        Store the perf by deployment in cache
+        - Number of failed requests per deployment
+        - Latency / output tokens per deployment
+
+        'deployment_id:daily_metrics:failed_requests'
+        'deployment_id:daily_metrics:latency_per_output_token'
+
+        Returns
+            int - count of metrics set (1 - if just latency, 2 - if failed + latency)
+        """
+
+        return_val = 0
+        try:
+            ## FAILED REQUESTS ##
+            if deployment_metrics.failed_request:
+                await self.internal_usage_cache.async_increment_cache(
+                    key="{}:{}".format(
+                        deployment_metrics.id,
+                        SlackAlertingCacheKeys.failed_requests_key.value,
+                    ),
+                    value=1,
+                )
+
+                return_val += 1
+
+            ## LATENCY ##
+            if deployment_metrics.latency_per_output_token is not None:
+                await self.internal_usage_cache.async_increment_cache(
+                    key="{}:{}".format(
+                        deployment_metrics.id, SlackAlertingCacheKeys.latency_key.value
+                    ),
+                    value=deployment_metrics.latency_per_output_token,
+                )
+
+                return_val += 1
+
+            return return_val
+        except Exception as e:
+            return 0
+
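Reviewer note: a minimal sketch of how the new per-deployment counters are expected to land in the cache, assuming an in-memory `DualCache` and the `DeploymentMetrics` / `SlackAlertingCacheKeys` models added above (the deployment id and latency values are made up):

```python
# Hypothetical sketch, not part of the diff: exercising async_update_daily_reports.
# Keys follow "{deployment_id}:{cache key enum value}".
import asyncio
import litellm
from litellm.integrations.slack_alerting import (
    SlackAlerting,
    DeploymentMetrics,
    SlackAlertingCacheKeys,
)


async def main():
    alerting = SlackAlerting(alerting=["slack"], alert_types=["daily_reports"])

    # one failed request + one latency sample for the same (made-up) deployment id
    await alerting.async_update_daily_reports(
        DeploymentMetrics(
            id="model-id-123",
            failed_request=True,
            latency_per_output_token=None,
            updated_at=litellm.utils.get_utc_datetime(),
        )
    )
    await alerting.async_update_daily_reports(
        DeploymentMetrics(
            id="model-id-123",
            failed_request=False,
            latency_per_output_token=0.012,  # seconds per output token
            updated_at=litellm.utils.get_utc_datetime(),
        )
    )

    key = "model-id-123:" + SlackAlertingCacheKeys.failed_requests_key.value
    print(await alerting.internal_usage_cache.async_get_cache(key=key))  # -> 1


asyncio.run(main())
```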
+    async def send_daily_reports(self, router) -> bool:
+        """
+        Send a daily report on:
+        - Top 5 deployments with most failed requests
+        - Top 5 slowest deployments (normalized by latency/output tokens)
+
+        Get the value from redis cache (if available) or in-memory and send it
+
+        Cleanup:
+        - reset values in cache -> prevent memory leak
+
+        Returns:
+            True -> if successfuly sent
+            False -> if not sent
+        """
+
+        ids = router.get_model_ids()
+
+        # get keys
+        failed_request_keys = [
+            "{}:{}".format(id, SlackAlertingCacheKeys.failed_requests_key.value)
+            for id in ids
+        ]
+        latency_keys = [
+            "{}:{}".format(id, SlackAlertingCacheKeys.latency_key.value) for id in ids
+        ]
+
+        combined_metrics_keys = failed_request_keys + latency_keys  # reduce cache calls
+
+        combined_metrics_values = await self.internal_usage_cache.async_batch_get_cache(
+            keys=combined_metrics_keys
+        )  # [1, 2, None, ..]
+
+        all_none = True
+        for val in combined_metrics_values:
+            if val is not None:
+                all_none = False
+
+        if all_none:
+            return False
+
+        failed_request_values = combined_metrics_values[
+            : len(failed_request_keys)
+        ]  # # [1, 2, None, ..]
+        latency_values = combined_metrics_values[len(failed_request_keys) :]
+
+        # find top 5 failed
+        ## Replace None values with a placeholder value (-1 in this case)
+        placeholder_value = 0
+        replaced_failed_values = [
+            value if value is not None else placeholder_value
+            for value in failed_request_values
+        ]
+
+        ## Get the indices of top 5 keys with the highest numerical values (ignoring None values)
+        top_5_failed = sorted(
+            range(len(replaced_failed_values)),
+            key=lambda i: replaced_failed_values[i],
+            reverse=True,
+        )[:5]
+
+        # find top 5 slowest
+        # Replace None values with a placeholder value (-1 in this case)
+        placeholder_value = 0
+        replaced_slowest_values = [
+            value if value is not None else placeholder_value
+            for value in latency_values
+        ]
+
+        # Get the indices of top 5 values with the highest numerical values (ignoring None values)
+        top_5_slowest = sorted(
+            range(len(replaced_slowest_values)),
+            key=lambda i: replaced_slowest_values[i],
+            reverse=True,
+        )[:5]
+
+        # format alert -> return the litellm model name + api base
+        message = f"\n\nHere are today's key metrics 📈: \n\n"
+
+        message += "\n\n*❗️ Top 5 Deployments with Most Failed Requests:*\n\n"
+        for i in range(len(top_5_failed)):
+            key = failed_request_keys[top_5_failed[i]].split(":")[0]
+            _deployment = router.get_model_info(key)
+            if isinstance(_deployment, dict):
+                deployment_name = _deployment["litellm_params"].get("model", "")
+            else:
+                return False
+
+            api_base = litellm.get_api_base(
+                model=deployment_name,
+                optional_params=(
+                    _deployment["litellm_params"] if _deployment is not None else {}
+                ),
+            )
+            if api_base is None:
+                api_base = ""
+            value = replaced_failed_values[top_5_failed[i]]
+            message += f"\t{i+1}. Deployment: `{deployment_name}`, Failed Requests: `{value}`, API Base: `{api_base}`\n"
+
+        message += "\n\n*😅 Top 5 Slowest Deployments:*\n\n"
+        for i in range(len(top_5_slowest)):
+            key = latency_keys[top_5_slowest[i]].split(":")[0]
+            _deployment = router.get_model_info(key)
+            if _deployment is not None:
+                deployment_name = _deployment["litellm_params"].get("model", "")
+            else:
+                deployment_name = ""
+            api_base = litellm.get_api_base(
+                model=deployment_name,
+                optional_params=(
+                    _deployment["litellm_params"] if _deployment is not None else {}
+                ),
+            )
+            value = round(replaced_slowest_values[top_5_slowest[i]], 3)
+            message += f"\t{i+1}. Deployment: `{deployment_name}`, Latency per output token: `{value}s/token`, API Base: `{api_base}`\n\n"
+
+        # cache cleanup -> reset values to 0
+        latency_cache_keys = [(key, 0) for key in latency_keys]
+        failed_request_cache_keys = [(key, 0) for key in failed_request_keys]
+        combined_metrics_cache_keys = latency_cache_keys + failed_request_cache_keys
+        await self.internal_usage_cache.async_batch_set_cache(
+            cache_list=combined_metrics_cache_keys
+        )
+
+        # send alert
+        await self.send_alert(message=message, level="Low", alert_type="daily_reports")
+
+        return True
 
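A small worked example of the index-sorting pattern used above for the "top 5" selection; the numbers are illustrative only:

```python
# Illustrative: pick the indices of the 5 largest values, treating None as 0,
# mirroring the selection logic in send_daily_reports above.
failed_request_values = [3, None, 7, 0, 1, 5, None]

placeholder_value = 0
replaced = [v if v is not None else placeholder_value for v in failed_request_values]

top_5 = sorted(range(len(replaced)), key=lambda i: replaced[i], reverse=True)[:5]
print(top_5)                          # [2, 5, 0, 4, 1] -> indices into the key list
print([replaced[i] for i in top_5])   # [7, 5, 3, 1, 0]
```

The indices (rather than the values) are kept so the matching cache key, and therefore the deployment id, can be recovered for the report message.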
     async def response_taking_too_long(
         self,
@@ -326,6 +498,11 @@ class SlackAlerting:
                 # in that case we fallback to the api base set in the request metadata
                 _metadata = request_data["metadata"]
                 _api_base = _metadata.get("api_base", "")
+
+                request_info = litellm.utils._add_key_name_and_team_to_alert(
+                    request_info=request_info, metadata=_metadata
+                )
+
                 if _api_base is None:
                     _api_base = ""
                 request_info += f"\nAPI Base: `{_api_base}`"
@@ -335,14 +512,13 @@ class SlackAlerting:
                 )
 
                 if "langfuse" in litellm.success_callback:
-                    request_info = self._add_langfuse_trace_id_to_alert(
-                        request_info=request_info,
+                    langfuse_url = self._add_langfuse_trace_id_to_alert(
                         request_data=request_data,
-                        type="hanging_request",
-                        start_time=start_time,
-                        end_time=end_time,
                     )
+
+                    if langfuse_url is not None:
+                        request_info += "\n🪢 Langfuse Trace: {}".format(langfuse_url)
+
                 # add deployment latencies to alert
                 _deployment_latency_map = self._get_deployment_latencies_to_alert(
                     metadata=request_data.get("metadata", {})
@@ -475,6 +651,53 @@ class SlackAlerting:
 
             return
 
+    async def model_added_alert(self, model_name: str, litellm_model_name: str):
+        model_info = litellm.model_cost.get(litellm_model_name, {})
+        model_info_str = ""
+        for k, v in model_info.items():
+            if k == "input_cost_per_token" or k == "output_cost_per_token":
+                # when converting to string it should not be 1.63e-06
+                v = "{:.8f}".format(v)
+
+            model_info_str += f"{k}: {v}\n"
+
+        message = f"""
+*🚅 New Model Added*
+Model Name: `{model_name}`
+
+Usage OpenAI Python SDK:
+```
+import openai
+client = openai.OpenAI(
+    api_key="your_api_key",
+    base_url={os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000")}
+)
+
+response = client.chat.completions.create(
+    model="{model_name}", # model to send to the proxy
+    messages = [
+        {{
+            "role": "user",
+            "content": "this is a test request, write a short poem"
+        }}
+    ]
+)
+```
+
+Model Info:
+```
+{model_info_str}
+```
+"""
+
+        await self.send_alert(
+            message=message, level="Low", alert_type="new_model_added"
+        )
+        pass
+
+    async def model_removed_alert(self, model_name: str):
+        pass
 
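One detail worth calling out from `model_added_alert`: per-token prices are rendered with a fixed-point format so Slack does not show scientific notation. A quick illustration (the price is made up):

```python
# Illustrative: why "{:.8f}".format(...) is used for the cost fields.
input_cost_per_token = 1.63e-06

print(str(input_cost_per_token))              # 1.63e-06  (hard to read in an alert)
print("{:.8f}".format(input_cost_per_token))  # 0.00000163
```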
     async def send_alert(
         self,
         message: str,
@@ -485,7 +708,11 @@ class SlackAlerting:
             "llm_requests_hanging",
             "budget_alerts",
             "db_exceptions",
+            "daily_reports",
+            "new_model_added",
+            "cooldown_deployment",
         ],
+        **kwargs,
     ):
         """
         Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
@@ -510,9 +737,16 @@ class SlackAlerting:
         # Get the current timestamp
         current_time = datetime.now().strftime("%H:%M:%S")
         _proxy_base_url = os.getenv("PROXY_BASE_URL", None)
-        formatted_message = (
-            f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
-        )
+        if alert_type == "daily_reports" or alert_type == "new_model_added":
+            formatted_message = message
+        else:
+            formatted_message = (
+                f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
+            )
+
+        if kwargs:
+            for key, value in kwargs.items():
+                formatted_message += f"\n\n{key}: `{value}`\n\n"
         if _proxy_base_url is not None:
             formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
 
@@ -522,6 +756,8 @@ class SlackAlerting:
             and alert_type in self.alert_to_webhook_url
         ):
             slack_webhook_url = self.alert_to_webhook_url[alert_type]
+        elif self.default_webhook_url is not None:
+            slack_webhook_url = self.default_webhook_url
         else:
             slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
 
@@ -539,3 +775,113 @@ class SlackAlerting:
             pass
         else:
             print("Error sending slack alert. Error=", response.text)  # noqa
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        """Log deployment latency"""
+        if "daily_reports" in self.alert_types:
+            model_id = (
+                kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
+            )
+            response_s: timedelta = end_time - start_time
+
+            final_value = response_s
+            total_tokens = 0
+
+            if isinstance(response_obj, litellm.ModelResponse):
+                completion_tokens = response_obj.usage.completion_tokens
+                final_value = float(response_s.total_seconds() / completion_tokens)
+
+            await self.async_update_daily_reports(
+                DeploymentMetrics(
+                    id=model_id,
+                    failed_request=False,
+                    latency_per_output_token=final_value,
+                    updated_at=litellm.utils.get_utc_datetime(),
+                )
+            )
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        """Log failure + deployment latency"""
+        if "daily_reports" in self.alert_types:
+            model_id = (
+                kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
+            )
+            await self.async_update_daily_reports(
+                DeploymentMetrics(
+                    id=model_id,
+                    failed_request=True,
+                    latency_per_output_token=None,
+                    updated_at=litellm.utils.get_utc_datetime(),
+                )
+            )
+        if "llm_exceptions" in self.alert_types:
+            original_exception = kwargs.get("exception", None)
+
+            await self.send_alert(
+                message="LLM API Failure - " + str(original_exception),
+                level="High",
+                alert_type="llm_exceptions",
+            )
+
+    async def _run_scheduler_helper(self, llm_router) -> bool:
+        """
+        Returns:
+        - True -> report sent
+        - False -> report not sent
+        """
+        report_sent_bool = False
+
+        report_sent = await self.internal_usage_cache.async_get_cache(
+            key=SlackAlertingCacheKeys.report_sent_key.value
+        )  # None | datetime
+
+        current_time = litellm.utils.get_utc_datetime()
+
+        if report_sent is None:
+            _current_time = current_time.isoformat()
+            await self.internal_usage_cache.async_set_cache(
+                key=SlackAlertingCacheKeys.report_sent_key.value,
+                value=_current_time,
+            )
+        else:
+            # check if current time - interval >= time last sent
+            delta = current_time - timedelta(
+                seconds=self.alerting_args.daily_report_frequency
+            )
+
+            if isinstance(report_sent, str):
+                report_sent = dt.fromisoformat(report_sent)
+
+            if delta >= report_sent:
+                # Sneak in the reporting logic here
+                await self.send_daily_reports(router=llm_router)
+                # Also, don't forget to update the report_sent time after sending the report!
+                _current_time = current_time.isoformat()
+                await self.internal_usage_cache.async_set_cache(
+                    key=SlackAlertingCacheKeys.report_sent_key.value,
+                    value=_current_time,
+                )
+                report_sent_bool = True
+
+        return report_sent_bool
+
+    async def _run_scheduled_daily_report(self, llm_router: Optional[Any] = None):
+        """
+        If 'daily_reports' enabled
+
+        Ping redis cache every 5 minutes to check if we should send the report
+
+        If yes -> call send_daily_report()
+        """
+        if llm_router is None or self.alert_types is None:
+            return
+
+        if "daily_reports" in self.alert_types:
+            while True:
+                await self._run_scheduler_helper(llm_router=llm_router)
+                interval = random.randint(
+                    self.alerting_args.report_check_interval - 3,
+                    self.alerting_args.report_check_interval + 3,
+                )  # shuffle to prevent collisions
+                await asyncio.sleep(interval)
+        return
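The report scheduler above compares the cached `report_sent` timestamp (stored as an ISO string) against `now - daily_report_frequency`, and sleeps a jittered `report_check_interval` between checks. A standalone sketch of that decision, assuming the `SlackAlertingArgs` defaults (the stored timestamp is made up):

```python
# Standalone sketch of the "is a daily report due?" check used by _run_scheduler_helper.
import random
from datetime import datetime, timedelta, timezone

daily_report_frequency = 12 * 60 * 60  # seconds (default: 12 hours)
report_check_interval = 5 * 60         # seconds (default: 5 minutes)

report_sent = "2024-05-01T00:00:00+00:00"  # ISO string, as stored in the cache
current_time = datetime.now(timezone.utc)

last_sent = datetime.fromisoformat(report_sent)
delta = current_time - timedelta(seconds=daily_report_frequency)

if delta >= last_sent:
    print("send the daily report and store the new report_sent timestamp")

# jittered sleep interval, matching the +/- 3s shuffle in the diff
interval = random.randint(report_check_interval - 3, report_check_interval + 3)
print(f"next check in {interval}s")
```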
@@ -2,7 +2,7 @@
 # On success + failure, log events to Supabase
 
 import dotenv, os
-import requests
+import requests  # type: ignore
 
 dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback

@@ -1,8 +1,8 @@
 import os, types, traceback
 import json
 from enum import Enum
-import requests
+import requests  # type: ignore
-import time, httpx
+import time, httpx  # type: ignore
 from typing import Callable, Optional
 from litellm.utils import ModelResponse, Choices, Message
 import litellm

@@ -1,12 +1,12 @@
 import os, types
 import json
 from enum import Enum
-import requests
+import requests  # type: ignore
 import time
 from typing import Callable, Optional
 import litellm
 from litellm.utils import ModelResponse, Choices, Message, Usage
-import httpx
+import httpx  # type: ignore
 
 
 class AlephAlphaError(Exception):
@@ -1,7 +1,7 @@
 import os, types
 import json
 from enum import Enum
-import requests, copy
+import requests, copy  # type: ignore
 import time
 from typing import Callable, Optional, List
 from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
@@ -9,7 +9,7 @@ import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from .base import BaseLLM
-import httpx
+import httpx  # type: ignore
 
 
 class AnthropicConstants(Enum):
@@ -84,6 +84,51 @@ class AnthropicConfig:
             and v is not None
         }
 
+    def get_supported_openai_params(self):
+        return [
+            "stream",
+            "stop",
+            "temperature",
+            "top_p",
+            "max_tokens",
+            "tools",
+            "tool_choice",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "max_tokens":
+                optional_params["max_tokens"] = value
+            if param == "tools":
+                optional_params["tools"] = value
+            if param == "stream" and value == True:
+                optional_params["stream"] = value
+            if param == "stop":
+                if isinstance(value, str):
+                    if (
+                        value == "\n"
+                    ) and litellm.drop_params == True:  # anthropic doesn't allow whitespace characters as stop-sequences
+                        continue
+                    value = [value]
+                elif isinstance(value, list):
+                    new_v = []
+                    for v in value:
+                        if (
+                            v == "\n"
+                        ) and litellm.drop_params == True:  # anthropic doesn't allow whitespace characters as stop-sequences
+                            continue
+                        new_v.append(v)
+                    if len(new_v) > 0:
+                        value = new_v
+                    else:
+                        continue
+                optional_params["stop_sequences"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+        return optional_params
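For context, a small sketch of how the new `AnthropicConfig` mapping translates OpenAI-style params; the values are made up, the module path `litellm.llms.anthropic` is assumed, and whether a `"\n"` stop string is dropped depends on `litellm.drop_params`:

```python
# Illustrative usage of the new AnthropicConfig param mapping.
import litellm
from litellm.llms.anthropic import AnthropicConfig  # assumed import path

config = AnthropicConfig()
print(config.get_supported_openai_params())
# ['stream', 'stop', 'temperature', 'top_p', 'max_tokens', 'tools', 'tool_choice']

litellm.drop_params = True
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 256, "stop": ["###", "\n"], "temperature": 0.2},
    optional_params={},
)
print(optional_params)
# {'max_tokens': 256, 'stop_sequences': ['###'], 'temperature': 0.2}
```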
 
 
 # makes headers for API call
 def validate_environment(api_key, user_headers):
@@ -139,11 +184,6 @@ class AnthropicChatCompletion(BaseLLM):
                     message=str(completion_response["error"]),
                     status_code=response.status_code,
                 )
-            elif len(completion_response["content"]) == 0:
-                raise AnthropicError(
-                    message="No content in response",
-                    status_code=response.status_code,
-                )
             else:
                 text_content = ""
                 tool_calls = []
@@ -1,4 +1,4 @@
-from typing import Optional, Union, Any
+from typing import Optional, Union, Any, Literal
 import types, requests
 from .base import BaseLLM
 from litellm.utils import (
@@ -8,14 +8,16 @@ from litellm.utils import (
     CustomStreamWrapper,
     convert_to_model_response_object,
     TranscriptionResponse,
+    get_secret,
 )
 from typing import Callable, Optional, BinaryIO
 from litellm import OpenAIConfig
 import litellm, json
-import httpx
+import httpx  # type: ignore
 from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
 from openai import AzureOpenAI, AsyncAzureOpenAI
 import uuid
+import os
 
 
 class AzureOpenAIError(Exception):
@@ -126,6 +128,51 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
     return azure_client_params
 
+
+def get_azure_ad_token_from_oidc(azure_ad_token: str):
+    azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
+    azure_tenant = os.getenv("AZURE_TENANT_ID", None)
+
+    if azure_client_id is None or azure_tenant is None:
+        raise AzureOpenAIError(
+            status_code=422,
+            message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
+        )
+
+    oidc_token = get_secret(azure_ad_token)
+
+    if oidc_token is None:
+        raise AzureOpenAIError(
+            status_code=401,
+            message="OIDC token could not be retrieved from secret manager.",
+        )
+
+    req_token = httpx.post(
+        f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token",
+        data={
+            "client_id": azure_client_id,
+            "grant_type": "client_credentials",
+            "scope": "https://cognitiveservices.azure.com/.default",
+            "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
+            "client_assertion": oidc_token,
+        },
+    )
+
+    if req_token.status_code != 200:
+        raise AzureOpenAIError(
+            status_code=req_token.status_code,
+            message=req_token.text,
+        )
+
+    possible_azure_ad_token = req_token.json().get("access_token", None)
+
+    if possible_azure_ad_token is None:
+        raise AzureOpenAIError(
+            status_code=422, message="Azure AD Token not returned"
+        )
+
+    return possible_azure_ad_token
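The `oidc/` prefix on `azure_ad_token` is what routes a request through this helper: the referenced OIDC token is pulled via `get_secret()` and exchanged at the tenant's `/oauth2/v2.0/token` endpoint using the client-credentials + federated client-assertion flow. A hedged sketch of a caller; the endpoint, deployment, and the exact `oidc/...` reference format are placeholders and depend on the OIDC support in `get_secret`:

```python
# Illustrative only: Azure AD auth via the new "oidc/" token handling.
import os
import litellm

os.environ["AZURE_CLIENT_ID"] = "<app-registration-client-id>"   # placeholder
os.environ["AZURE_TENANT_ID"] = "<tenant-id>"                    # placeholder

response = litellm.completion(
    model="azure/my-gpt-4-deployment",                           # placeholder deployment
    messages=[{"role": "user", "content": "hello"}],
    api_base="https://my-resource.openai.azure.com",             # placeholder endpoint
    api_version="2024-02-15-preview",
    azure_ad_token="oidc/<provider>/<audience>",                 # resolved by get_secret()
)
```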
 
 
 class AzureChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
@@ -137,6 +184,8 @@ class AzureChatCompletion(BaseLLM):
         if api_key is not None:
             headers["api-key"] = api_key
         elif azure_ad_token is not None:
+            if azure_ad_token.startswith("oidc/"):
+                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             headers["Authorization"] = f"Bearer {azure_ad_token}"
         return headers
 
@@ -151,7 +200,7 @@ class AzureChatCompletion(BaseLLM):
         api_type: str,
         azure_ad_token: str,
         print_verbose: Callable,
-        timeout,
+        timeout: Union[float, httpx.Timeout],
         logging_obj,
         optional_params,
         litellm_params,
@@ -189,6 +238,9 @@ class AzureChatCompletion(BaseLLM):
             if api_key is not None:
                 azure_client_params["api_key"] = api_key
             elif azure_ad_token is not None:
+                if azure_ad_token.startswith("oidc/"):
+                    azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
+
                 azure_client_params["azure_ad_token"] = azure_ad_token
 
             if acompletion is True:
@@ -276,6 +328,8 @@ class AzureChatCompletion(BaseLLM):
             if api_key is not None:
                 azure_client_params["api_key"] = api_key
             elif azure_ad_token is not None:
+                if azure_ad_token.startswith("oidc/"):
+                    azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                 azure_client_params["azure_ad_token"] = azure_ad_token
             if client is None:
                 azure_client = AzureOpenAI(**azure_client_params)
@@ -351,6 +405,8 @@ class AzureChatCompletion(BaseLLM):
             if api_key is not None:
                 azure_client_params["api_key"] = api_key
             elif azure_ad_token is not None:
+                if azure_ad_token.startswith("oidc/"):
+                    azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                 azure_client_params["azure_ad_token"] = azure_ad_token
 
             # setting Azure client
@@ -422,6 +478,8 @@ class AzureChatCompletion(BaseLLM):
         if api_key is not None:
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
+            if azure_ad_token.startswith("oidc/"):
+                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token
         if client is None:
             azure_client = AzureOpenAI(**azure_client_params)
@@ -478,6 +536,8 @@ class AzureChatCompletion(BaseLLM):
         if api_key is not None:
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
+            if azure_ad_token.startswith("oidc/"):
+                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token
         if client is None:
             azure_client = AsyncAzureOpenAI(**azure_client_params)
@@ -599,6 +659,8 @@ class AzureChatCompletion(BaseLLM):
         if api_key is not None:
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
+            if azure_ad_token.startswith("oidc/"):
+                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token
 
         ## LOGGING
@@ -755,6 +817,8 @@ class AzureChatCompletion(BaseLLM):
         if api_key is not None:
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
+            if azure_ad_token.startswith("oidc/"):
+                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token
 
         if aimg_generation == True:
@@ -833,6 +897,8 @@ class AzureChatCompletion(BaseLLM):
         if api_key is not None:
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
+            if azure_ad_token.startswith("oidc/"):
+                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token
 
         if max_retries is not None:
@@ -952,6 +1018,81 @@ class AzureChatCompletion(BaseLLM):
             )
             raise e
 
+    def get_headers(
+        self,
+        model: Optional[str],
+        api_key: str,
+        api_base: str,
+        api_version: str,
+        timeout: float,
+        mode: str,
+        messages: Optional[list] = None,
+        input: Optional[list] = None,
+        prompt: Optional[str] = None,
+    ) -> dict:
+        client_session = litellm.client_session or httpx.Client(
+            transport=CustomHTTPTransport(),  # handle dall-e-2 calls
+        )
+        if "gateway.ai.cloudflare.com" in api_base:
+            ## build base url - assume api base includes resource name
+            if not api_base.endswith("/"):
+                api_base += "/"
+            api_base += f"{model}"
+            client = AzureOpenAI(
+                base_url=api_base,
+                api_version=api_version,
+                api_key=api_key,
+                timeout=timeout,
+                http_client=client_session,
+            )
+            model = None
+            # cloudflare ai gateway, needs model=None
+        else:
+            client = AzureOpenAI(
+                api_version=api_version,
+                azure_endpoint=api_base,
+                api_key=api_key,
+                timeout=timeout,
+                http_client=client_session,
+            )
+
+            # only run this check if it's not cloudflare ai gateway
+            if model is None and mode != "image_generation":
+                raise Exception("model is not set")
+
+        completion = None
+
+        if messages is None:
+            messages = [{"role": "user", "content": "Hey"}]
+        try:
+            completion = client.chat.completions.with_raw_response.create(
+                model=model,  # type: ignore
+                messages=messages,  # type: ignore
+            )
+        except Exception as e:
+            raise e
+        response = {}
+
+        if completion is None or not hasattr(completion, "headers"):
+            raise Exception("invalid completion response")
+
+        if (
+            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
+        ):  # not provided for dall-e requests
+            response["x-ratelimit-remaining-requests"] = completion.headers[
+                "x-ratelimit-remaining-requests"
+            ]
+
+        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
+            response["x-ratelimit-remaining-tokens"] = completion.headers[
+                "x-ratelimit-remaining-tokens"
+            ]
+
+        if completion.headers.get("x-ms-region", None) is not None:
+            response["x-ms-region"] = completion.headers["x-ms-region"]
+
+        return response
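A short sketch of what the new `get_headers` helper returns, assuming a reachable Azure deployment; all identifiers below are placeholders:

```python
# Illustrative: get_headers() surfaces Azure rate-limit / region headers
# from a raw chat.completions response.
from litellm.llms.azure import AzureChatCompletion  # assumed import path

azure_chat = AzureChatCompletion()
headers = azure_chat.get_headers(
    model="my-gpt-35-deployment",                      # placeholder deployment
    api_key="<azure api key>",
    api_base="https://my-resource.openai.azure.com",   # placeholder endpoint
    api_version="2024-02-15-preview",
    timeout=10.0,
    mode="chat",
)
print(headers)
# e.g. {'x-ratelimit-remaining-requests': '...',
#       'x-ratelimit-remaining-tokens': '...',
#       'x-ms-region': 'East US'}
```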
 
     async def ahealth_check(
         self,
         model: Optional[str],
@@ -963,7 +1104,7 @@ class AzureChatCompletion(BaseLLM):
         messages: Optional[list] = None,
         input: Optional[list] = None,
         prompt: Optional[str] = None,
-    ):
+    ) -> dict:
         client_session = litellm.aclient_session or httpx.AsyncClient(
             transport=AsyncCustomHTTPTransport(),  # handle dall-e-2 calls
         )
@@ -1040,4 +1181,8 @@ class AzureChatCompletion(BaseLLM):
             response["x-ratelimit-remaining-tokens"] = completion.headers[
                 "x-ratelimit-remaining-tokens"
             ]
+
+        if completion.headers.get("x-ms-region", None) is not None:
+            response["x-ms-region"] = completion.headers["x-ms-region"]
+
         return response

@@ -1,5 +1,5 @@
 from typing import Optional, Union, Any
-import types, requests
+import types, requests  # type: ignore
 from .base import BaseLLM
 from litellm.utils import (
     ModelResponse,

@@ -1,7 +1,7 @@
 import os
 import json
 from enum import Enum
-import requests
+import requests  # type: ignore
 import time
 from typing import Callable
 from litellm.utils import ModelResponse, Usage
@@ -4,7 +4,13 @@ from enum import Enum
 import time, uuid
 from typing import Callable, Optional, Any, Union, List
 import litellm
-from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse
+from litellm.utils import (
+    ModelResponse,
+    get_secret,
+    Usage,
+    ImageResponse,
+    map_finish_reason,
+)
 from .prompt_templates.factory import (
     prompt_factory,
     custom_prompt,
@@ -157,6 +163,7 @@ class AmazonAnthropicClaude3Config:
             "stop",
             "temperature",
             "top_p",
+            "extra_headers",
         ]
 
     def map_openai_params(self, non_default_params: dict, optional_params: dict):
@@ -524,6 +531,17 @@ class AmazonStabilityConfig:
         }
 
+
+def add_custom_header(headers):
+    """Closure to capture the headers and add them."""
+
+    def callback(request, **kwargs):
+        """Actual callback function that Boto3 will call."""
+        for header_name, header_value in headers.items():
+            request.headers.add_header(header_name, header_value)
+
+    return callback
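The closure above is wired into boto3's event system so every Bedrock call gets the extra headers added just before request signing. A hedged sketch of the same registration outside litellm; the header name and value are made up, and the import path `litellm.llms.bedrock` is assumed:

```python
# Illustrative: registering the header-injecting callback on a boto3 Bedrock client.
import boto3
from litellm.llms.bedrock import add_custom_header  # assumed import path

client = boto3.client("bedrock-runtime", region_name="us-east-1")

# fire on every bedrock-runtime operation, right before SigV4 signing
client.meta.events.register(
    "before-sign.bedrock-runtime.*", add_custom_header({"X-Trace-Id": "abc-123"})
)
```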
+
+
 def init_bedrock_client(
     region_name=None,
     aws_access_key_id: Optional[str] = None,
@@ -533,12 +551,13 @@ def init_bedrock_client(
     aws_session_name: Optional[str] = None,
     aws_profile_name: Optional[str] = None,
     aws_role_name: Optional[str] = None,
-    timeout: Optional[int] = None,
+    aws_web_identity_token: Optional[str] = None,
+    extra_headers: Optional[dict] = None,
+    timeout: Optional[Union[float, httpx.Timeout]] = None,
 ):
     # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
     litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
     standard_aws_region_name = get_secret("AWS_REGION", None)
 
     ## CHECK IS  'os.environ/' passed in
     # Define the list of parameters to check
     params_to_check = [
@@ -549,6 +568,7 @@ def init_bedrock_client(
         aws_session_name,
         aws_profile_name,
         aws_role_name,
+        aws_web_identity_token,
     ]
 
     # Iterate over parameters and update if needed
@@ -564,6 +584,7 @@ def init_bedrock_client(
         aws_session_name,
         aws_profile_name,
         aws_role_name,
+        aws_web_identity_token,
     ) = params_to_check
 
     ### SET REGION NAME
@@ -592,10 +613,48 @@ def init_bedrock_client(
 
     import boto3
 
-    config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
+    if isinstance(timeout, float):
+        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
+    elif isinstance(timeout, httpx.Timeout):
+        config = boto3.session.Config(
+            connect_timeout=timeout.connect, read_timeout=timeout.read
+        )
+    else:
+        config = boto3.session.Config()
+
     ### CHECK STS ###
-    if aws_role_name is not None and aws_session_name is not None:
+    if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
+        oidc_token = get_secret(aws_web_identity_token)
+
+        if oidc_token is None:
+            raise BedrockError(
+                message="OIDC token could not be retrieved from secret manager.",
+                status_code=401,
+            )
+
+        sts_client = boto3.client(
+            "sts"
+        )
+
+        # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
+        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
+        sts_response = sts_client.assume_role_with_web_identity(
+            RoleArn=aws_role_name,
+            RoleSessionName=aws_session_name,
+            WebIdentityToken=oidc_token,
+            DurationSeconds=3600,
+        )
+
+        client = boto3.client(
+            service_name="bedrock-runtime",
+            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
+            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
+            aws_session_token=sts_response["Credentials"]["SessionToken"],
+            region_name=region_name,
+            endpoint_url=endpoint_url,
+            config=config,
+        )
+    elif aws_role_name is not None and aws_session_name is not None:
         # use sts if role name passed in
         sts_client = boto3.client(
             "sts",
@@ -647,6 +706,10 @@ def init_bedrock_client(
             endpoint_url=endpoint_url,
             config=config,
         )
+    if extra_headers:
+        client.meta.events.register(
+            "before-sign.bedrock-runtime.*", add_custom_header(extra_headers)
+        )
+
     return client
 
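Putting the new `init_bedrock_client` parameters together: when `aws_web_identity_token`, `aws_role_name`, and `aws_session_name` are all provided, the client is built from `AssumeRoleWithWebIdentity` credentials, and `extra_headers` ride along via the event hook registered above. A hedged sketch of a caller; the role ARN, token reference, header, and model are placeholders:

```python
# Illustrative only: Bedrock via OIDC web-identity federation + extra headers.
import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "hello"}],
    aws_region_name="us-east-1",
    aws_web_identity_token="oidc/<provider>/<audience>",  # resolved via get_secret()
    aws_role_name="arn:aws:iam::123456789012:role/my-bedrock-role",  # placeholder
    aws_session_name="litellm-session",
    extra_headers={"X-Trace-Id": "abc-123"},  # injected before SigV4 signing
)
```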
@@ -710,6 +773,7 @@ def completion(
     litellm_params=None,
     logger_fn=None,
     timeout=None,
+    extra_headers: Optional[dict] = None,
 ):
     exception_mapping_worked = False
     _is_function_call = False
@@ -725,6 +789,7 @@ def completion(
         aws_bedrock_runtime_endpoint = optional_params.pop(
             "aws_bedrock_runtime_endpoint", None
         )
+        aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
 
         # use passed in BedrockRuntime.Client if provided, otherwise create a new one
         client = optional_params.pop("aws_bedrock_client", None)
@@ -739,6 +804,8 @@ def completion(
             aws_role_name=aws_role_name,
             aws_session_name=aws_session_name,
             aws_profile_name=aws_profile_name,
+            aws_web_identity_token=aws_web_identity_token,
+            extra_headers=extra_headers,
             timeout=timeout,
         )
 
@@ -1043,7 +1110,9 @@ def completion(
                 logging_obj=logging_obj,
             )
 
-            model_response["finish_reason"] = response_body["stop_reason"]
+            model_response["finish_reason"] = map_finish_reason(
+                response_body["stop_reason"]
+            )
             _usage = litellm.Usage(
                 prompt_tokens=response_body["usage"]["input_tokens"],
                 completion_tokens=response_body["usage"]["output_tokens"],
@@ -1194,7 +1263,7 @@ def _embedding_func_single(
         "input_type", "search_document"
     )  # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
     data = {"texts": [input], **inference_params}  # type: ignore
-    body = json.dumps(data).encode("utf-8")
+    body = json.dumps(data).encode("utf-8")  # type: ignore
     ## LOGGING
     request_str = f"""
     response = client.invoke_model(
@@ -1258,6 +1327,7 @@ def embedding(
     aws_bedrock_runtime_endpoint = optional_params.pop(
         "aws_bedrock_runtime_endpoint", None
     )
+    aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
 
     # use passed in BedrockRuntime.Client if provided, otherwise create a new one
     client = init_bedrock_client(
@@ -1265,6 +1335,7 @@ def embedding(
         aws_secret_access_key=aws_secret_access_key,
         aws_region_name=aws_region_name,
         aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
+        aws_web_identity_token=aws_web_identity_token,
         aws_role_name=aws_role_name,
         aws_session_name=aws_session_name,
     )
@@ -1347,6 +1418,7 @@ def image_generation(
     aws_bedrock_runtime_endpoint = optional_params.pop(
         "aws_bedrock_runtime_endpoint", None
     )
+    aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
 
     # use passed in BedrockRuntime.Client if provided, otherwise create a new one
     client = init_bedrock_client(
@@ -1354,6 +1426,7 @@ def image_generation(
         aws_secret_access_key=aws_secret_access_key,
         aws_region_name=aws_region_name,
         aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
+        aws_web_identity_token=aws_web_identity_token,
         aws_role_name=aws_role_name,
|
aws_role_name=aws_role_name,
|
||||||
aws_session_name=aws_session_name,
|
aws_session_name=aws_session_name,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
@ -1386,7 +1459,7 @@ def image_generation(
|
||||||
## LOGGING
|
## LOGGING
|
||||||
request_str = f"""
|
request_str = f"""
|
||||||
response = client.invoke_model(
|
response = client.invoke_model(
|
||||||
body={body},
|
body={body}, # type: ignore
|
||||||
modelId={modelId},
|
modelId={modelId},
|
||||||
accept="application/json",
|
accept="application/json",
|
||||||
contentType="application/json",
|
contentType="application/json",
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
import os, types
|
import os, types
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
import litellm
|
import litellm
|
||||||
import httpx
|
import httpx # type: ignore
|
||||||
from litellm.utils import ModelResponse, Usage
|
from litellm.utils import ModelResponse, Usage
|
||||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
import os, types
|
import os, types
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time, traceback
|
import time, traceback
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||||
import litellm
|
import litellm
|
||||||
import httpx
|
import httpx # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class CohereError(Exception):
|
class CohereError(Exception):
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
import os, types
|
import os, types
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time, traceback
|
import time, traceback
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||||
import litellm
|
import litellm
|
||||||
import httpx
|
import httpx # type: ignore
|
||||||
from .prompt_templates.factory import cohere_message_pt
|
from .prompt_templates.factory import cohere_message_pt
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -6,10 +6,12 @@ import httpx, requests
|
||||||
from .base import BaseLLM
|
from .base import BaseLLM
|
||||||
import time
|
import time
|
||||||
import litellm
|
import litellm
|
||||||
from typing import Callable, Dict, List, Any
|
from typing import Callable, Dict, List, Any, Literal
|
||||||
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
|
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
|
from litellm.types.completion import ChatCompletionMessageToolCallParam
|
||||||
|
import enum
|
||||||
|
|
||||||
|
|
||||||
class HuggingfaceError(Exception):
|
class HuggingfaceError(Exception):
|
||||||
|
@ -39,11 +41,29 @@ class HuggingfaceError(Exception):
|
||||||
) # Call the base class constructor with the parameters it needs
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
|
hf_task_list = [
|
||||||
|
"text-generation-inference",
|
||||||
|
"conversational",
|
||||||
|
"text-classification",
|
||||||
|
"text-generation",
|
||||||
|
]
|
||||||
|
|
||||||
|
hf_tasks = Literal[
|
||||||
|
"text-generation-inference",
|
||||||
|
"conversational",
|
||||||
|
"text-classification",
|
||||||
|
"text-generation",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class HuggingfaceConfig:
|
class HuggingfaceConfig:
|
||||||
"""
|
"""
|
||||||
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
|
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
hf_task: Optional[hf_tasks] = (
|
||||||
|
None # litellm-specific param, used to know the api spec to use when calling huggingface api
|
||||||
|
)
|
||||||
best_of: Optional[int] = None
|
best_of: Optional[int] = None
|
||||||
decoder_input_details: Optional[bool] = None
|
decoder_input_details: Optional[bool] = None
|
||||||
details: Optional[bool] = True # enables returning logprobs + best of
|
details: Optional[bool] = True # enables returning logprobs + best of
|
||||||
|
@ -101,6 +121,51 @@ class HuggingfaceConfig:
|
||||||
and v is not None
|
and v is not None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_supported_openai_params(self):
|
||||||
|
return [
|
||||||
|
"stream",
|
||||||
|
"temperature",
|
||||||
|
"max_tokens",
|
||||||
|
"top_p",
|
||||||
|
"stop",
|
||||||
|
"n",
|
||||||
|
"echo",
|
||||||
|
]
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self, non_default_params: dict, optional_params: dict
|
||||||
|
) -> dict:
|
||||||
|
for param, value in non_default_params.items():
|
||||||
|
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
|
||||||
|
if param == "temperature":
|
||||||
|
if value == 0.0 or value == 0:
|
||||||
|
# hugging face exception raised when temp==0
|
||||||
|
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
|
||||||
|
value = 0.01
|
||||||
|
optional_params["temperature"] = value
|
||||||
|
if param == "top_p":
|
||||||
|
optional_params["top_p"] = value
|
||||||
|
if param == "n":
|
||||||
|
optional_params["best_of"] = value
|
||||||
|
optional_params["do_sample"] = (
|
||||||
|
True # Need to sample if you want best of for hf inference endpoints
|
||||||
|
)
|
||||||
|
if param == "stream":
|
||||||
|
optional_params["stream"] = value
|
||||||
|
if param == "stop":
|
||||||
|
optional_params["stop"] = value
|
||||||
|
if param == "max_tokens":
|
||||||
|
# HF TGI raises the following exception when max_new_tokens==0
|
||||||
|
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
|
||||||
|
if value == 0:
|
||||||
|
value = 1
|
||||||
|
optional_params["max_new_tokens"] = value
|
||||||
|
if param == "echo":
|
||||||
|
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
|
||||||
|
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
|
||||||
|
optional_params["decoder_input_details"] = True
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
|
||||||
def output_parser(generated_text: str):
|
def output_parser(generated_text: str):
|
||||||
"""
|
"""
|
||||||
|
@ -162,16 +227,18 @@ def read_tgi_conv_models():
|
||||||
return set(), set()
|
return set(), set()
|
||||||
|
|
||||||
|
|
||||||
def get_hf_task_for_model(model):
|
def get_hf_task_for_model(model: str) -> hf_tasks:
|
||||||
# read text file, cast it to set
|
# read text file, cast it to set
|
||||||
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
|
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
|
||||||
|
if model.split("/")[0] in hf_task_list:
|
||||||
|
return model.split("/")[0] # type: ignore
|
||||||
tgi_models, conversational_models = read_tgi_conv_models()
|
tgi_models, conversational_models = read_tgi_conv_models()
|
||||||
if model in tgi_models:
|
if model in tgi_models:
|
||||||
return "text-generation-inference"
|
return "text-generation-inference"
|
||||||
elif model in conversational_models:
|
elif model in conversational_models:
|
||||||
return "conversational"
|
return "conversational"
|
||||||
elif "roneneldan/TinyStories" in model:
|
elif "roneneldan/TinyStories" in model:
|
||||||
return None
|
return "text-generation"
|
||||||
else:
|
else:
|
||||||
return "text-generation-inference" # default to tgi
|
return "text-generation-inference" # default to tgi
|
||||||
|
|
||||||
|
@ -202,7 +269,7 @@ class Huggingface(BaseLLM):
|
||||||
self,
|
self,
|
||||||
completion_response,
|
completion_response,
|
||||||
model_response,
|
model_response,
|
||||||
task,
|
task: hf_tasks,
|
||||||
optional_params,
|
optional_params,
|
||||||
encoding,
|
encoding,
|
||||||
input_text,
|
input_text,
|
||||||
|
@ -270,6 +337,10 @@ class Huggingface(BaseLLM):
|
||||||
)
|
)
|
||||||
choices_list.append(choice_obj)
|
choices_list.append(choice_obj)
|
||||||
model_response["choices"].extend(choices_list)
|
model_response["choices"].extend(choices_list)
|
||||||
|
elif task == "text-classification":
|
||||||
|
model_response["choices"][0]["message"]["content"] = json.dumps(
|
||||||
|
completion_response
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if len(completion_response[0]["generated_text"]) > 0:
|
if len(completion_response[0]["generated_text"]) > 0:
|
||||||
model_response["choices"][0]["message"]["content"] = output_parser(
|
model_response["choices"][0]["message"]["content"] = output_parser(
|
||||||
|
@ -322,9 +393,9 @@ class Huggingface(BaseLLM):
|
||||||
encoding,
|
encoding,
|
||||||
api_key,
|
api_key,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
|
optional_params: dict,
|
||||||
custom_prompt_dict={},
|
custom_prompt_dict={},
|
||||||
acompletion: bool = False,
|
acompletion: bool = False,
|
||||||
optional_params=None,
|
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
):
|
):
|
||||||
|
@ -333,6 +404,12 @@ class Huggingface(BaseLLM):
|
||||||
try:
|
try:
|
||||||
headers = self.validate_environment(api_key, headers)
|
headers = self.validate_environment(api_key, headers)
|
||||||
task = get_hf_task_for_model(model)
|
task = get_hf_task_for_model(model)
|
||||||
|
## VALIDATE API FORMAT
|
||||||
|
if task is None or not isinstance(task, str) or task not in hf_task_list:
|
||||||
|
raise Exception(
|
||||||
|
"Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
|
||||||
|
)
|
||||||
|
|
||||||
print_verbose(f"{model}, {task}")
|
print_verbose(f"{model}, {task}")
|
||||||
completion_url = ""
|
completion_url = ""
|
||||||
input_text = ""
|
input_text = ""
|
||||||
|
@ -399,10 +476,11 @@ class Huggingface(BaseLLM):
|
||||||
data = {
|
data = {
|
||||||
"inputs": prompt,
|
"inputs": prompt,
|
||||||
"parameters": optional_params,
|
"parameters": optional_params,
|
||||||
"stream": (
|
"stream": ( # type: ignore
|
||||||
True
|
True
|
||||||
if "stream" in optional_params
|
if "stream" in optional_params
|
||||||
and optional_params["stream"] == True
|
and isinstance(optional_params["stream"], bool)
|
||||||
|
and optional_params["stream"] == True # type: ignore
|
||||||
else False
|
else False
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
@ -432,14 +510,15 @@ class Huggingface(BaseLLM):
|
||||||
inference_params.pop("return_full_text")
|
inference_params.pop("return_full_text")
|
||||||
data = {
|
data = {
|
||||||
"inputs": prompt,
|
"inputs": prompt,
|
||||||
"parameters": inference_params,
|
}
|
||||||
"stream": (
|
if task == "text-generation-inference":
|
||||||
|
data["parameters"] = inference_params
|
||||||
|
data["stream"] = ( # type: ignore
|
||||||
True
|
True
|
||||||
if "stream" in optional_params
|
if "stream" in optional_params
|
||||||
and optional_params["stream"] == True
|
and optional_params["stream"] == True
|
||||||
else False
|
else False
|
||||||
),
|
)
|
||||||
}
|
|
||||||
input_text = prompt
|
input_text = prompt
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
|
@ -530,10 +609,10 @@ class Huggingface(BaseLLM):
|
||||||
isinstance(completion_response, dict)
|
isinstance(completion_response, dict)
|
||||||
and "error" in completion_response
|
and "error" in completion_response
|
||||||
):
|
):
|
||||||
print_verbose(f"completion error: {completion_response['error']}")
|
print_verbose(f"completion error: {completion_response['error']}") # type: ignore
|
||||||
print_verbose(f"response.status_code: {response.status_code}")
|
print_verbose(f"response.status_code: {response.status_code}")
|
||||||
raise HuggingfaceError(
|
raise HuggingfaceError(
|
||||||
message=completion_response["error"],
|
message=completion_response["error"], # type: ignore
|
||||||
status_code=response.status_code,
|
status_code=response.status_code,
|
||||||
)
|
)
|
||||||
return self.convert_to_model_response_object(
|
return self.convert_to_model_response_object(
|
||||||
|
@ -562,7 +641,7 @@ class Huggingface(BaseLLM):
|
||||||
data: dict,
|
data: dict,
|
||||||
headers: dict,
|
headers: dict,
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
task: str,
|
task: hf_tasks,
|
||||||
encoding: Any,
|
encoding: Any,
|
||||||
input_text: str,
|
input_text: str,
|
||||||
model: str,
|
model: str,
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import os, types
|
import os, types
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time, traceback
|
import time, traceback
|
||||||
from typing import Callable, Optional, List
|
from typing import Callable, Optional, List
|
||||||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import os, types
|
import os, types
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
import litellm
|
import litellm
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
import requests, types, time
|
from itertools import chain
|
||||||
|
import requests, types, time # type: ignore
|
||||||
import json, uuid
|
import json, uuid
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import litellm
|
import litellm
|
||||||
import httpx, aiohttp, asyncio
|
import httpx, aiohttp, asyncio # type: ignore
|
||||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
|
|
||||||
|
|
||||||
|
@ -212,25 +213,31 @@ def get_ollama_response(
|
||||||
|
|
||||||
## RESPONSE OBJECT
|
## RESPONSE OBJECT
|
||||||
model_response["choices"][0]["finish_reason"] = "stop"
|
model_response["choices"][0]["finish_reason"] = "stop"
|
||||||
if optional_params.get("format", "") == "json":
|
if data.get("format", "") == "json":
|
||||||
function_call = json.loads(response_json["response"])
|
function_call = json.loads(response_json["response"])
|
||||||
message = litellm.Message(
|
message = litellm.Message(
|
||||||
content=None,
|
content=None,
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
{
|
{
|
||||||
"id": f"call_{str(uuid.uuid4())}",
|
"id": f"call_{str(uuid.uuid4())}",
|
||||||
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
|
"function": {
|
||||||
|
"name": function_call["name"],
|
||||||
|
"arguments": json.dumps(function_call["arguments"]),
|
||||||
|
},
|
||||||
"type": "function",
|
"type": "function",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
model_response["choices"][0]["message"] = message
|
model_response["choices"][0]["message"] = message
|
||||||
|
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||||
else:
|
else:
|
||||||
model_response["choices"][0]["message"]["content"] = response_json["response"]
|
model_response["choices"][0]["message"]["content"] = response_json["response"]
|
||||||
model_response["created"] = int(time.time())
|
model_response["created"] = int(time.time())
|
||||||
model_response["model"] = "ollama/" + model
|
model_response["model"] = "ollama/" + model
|
||||||
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=()))) # type: ignore
|
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=()))) # type: ignore
|
||||||
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
|
completion_tokens = response_json.get(
|
||||||
|
"eval_count", len(response_json.get("message", dict()).get("content", ""))
|
||||||
|
)
|
||||||
model_response["usage"] = litellm.Usage(
|
model_response["usage"] = litellm.Usage(
|
||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
completion_tokens=completion_tokens,
|
completion_tokens=completion_tokens,
|
||||||
|
@ -255,8 +262,37 @@ def ollama_completion_stream(url, data, logging_obj):
|
||||||
custom_llm_provider="ollama",
|
custom_llm_provider="ollama",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
)
|
)
|
||||||
for transformed_chunk in streamwrapper:
|
# If format is JSON, this was a function call
|
||||||
yield transformed_chunk
|
# Gather all chunks and return the function call as one delta to simplify parsing
|
||||||
|
if data.get("format", "") == "json":
|
||||||
|
first_chunk = next(streamwrapper)
|
||||||
|
response_content = "".join(
|
||||||
|
chunk.choices[0].delta.content
|
||||||
|
for chunk in chain([first_chunk], streamwrapper)
|
||||||
|
if chunk.choices[0].delta.content
|
||||||
|
)
|
||||||
|
|
||||||
|
function_call = json.loads(response_content)
|
||||||
|
delta = litellm.utils.Delta(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"id": f"call_{str(uuid.uuid4())}",
|
||||||
|
"function": {
|
||||||
|
"name": function_call["name"],
|
||||||
|
"arguments": json.dumps(function_call["arguments"]),
|
||||||
|
},
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
model_response = first_chunk
|
||||||
|
model_response["choices"][0]["delta"] = delta
|
||||||
|
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||||
|
yield model_response
|
||||||
|
else:
|
||||||
|
for transformed_chunk in streamwrapper:
|
||||||
|
yield transformed_chunk
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
@ -278,8 +314,40 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
|
||||||
custom_llm_provider="ollama",
|
custom_llm_provider="ollama",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
)
|
)
|
||||||
async for transformed_chunk in streamwrapper:
|
|
||||||
yield transformed_chunk
|
# If format is JSON, this was a function call
|
||||||
|
# Gather all chunks and return the function call as one delta to simplify parsing
|
||||||
|
if data.get("format", "") == "json":
|
||||||
|
first_chunk = await anext(streamwrapper)
|
||||||
|
first_chunk_content = first_chunk.choices[0].delta.content or ""
|
||||||
|
response_content = first_chunk_content + "".join(
|
||||||
|
[
|
||||||
|
chunk.choices[0].delta.content
|
||||||
|
async for chunk in streamwrapper
|
||||||
|
if chunk.choices[0].delta.content
|
||||||
|
]
|
||||||
|
)
|
||||||
|
function_call = json.loads(response_content)
|
||||||
|
delta = litellm.utils.Delta(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"id": f"call_{str(uuid.uuid4())}",
|
||||||
|
"function": {
|
||||||
|
"name": function_call["name"],
|
||||||
|
"arguments": json.dumps(function_call["arguments"]),
|
||||||
|
},
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
model_response = first_chunk
|
||||||
|
model_response["choices"][0]["delta"] = delta
|
||||||
|
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||||
|
yield model_response
|
||||||
|
else:
|
||||||
|
async for transformed_chunk in streamwrapper:
|
||||||
|
yield transformed_chunk
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise e
|
raise e
|
||||||
|
@ -317,12 +385,16 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
{
|
{
|
||||||
"id": f"call_{str(uuid.uuid4())}",
|
"id": f"call_{str(uuid.uuid4())}",
|
||||||
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
|
"function": {
|
||||||
|
"name": function_call["name"],
|
||||||
|
"arguments": json.dumps(function_call["arguments"]),
|
||||||
|
},
|
||||||
"type": "function",
|
"type": "function",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
model_response["choices"][0]["message"] = message
|
model_response["choices"][0]["message"] = message
|
||||||
|
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||||
else:
|
else:
|
||||||
model_response["choices"][0]["message"]["content"] = response_json[
|
model_response["choices"][0]["message"]["content"] = response_json[
|
||||||
"response"
|
"response"
|
||||||
|
@ -330,7 +402,10 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
|
||||||
model_response["created"] = int(time.time())
|
model_response["created"] = int(time.time())
|
||||||
model_response["model"] = "ollama/" + data["model"]
|
model_response["model"] = "ollama/" + data["model"]
|
||||||
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=()))) # type: ignore
|
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=()))) # type: ignore
|
||||||
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
|
completion_tokens = response_json.get(
|
||||||
|
"eval_count",
|
||||||
|
len(response_json.get("message", dict()).get("content", "")),
|
||||||
|
)
|
||||||
model_response["usage"] = litellm.Usage(
|
model_response["usage"] = litellm.Usage(
|
||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
completion_tokens=completion_tokens,
|
completion_tokens=completion_tokens,
|
||||||
|
@ -417,3 +492,25 @@ async def ollama_aembeddings(
|
||||||
"total_tokens": total_input_tokens,
|
"total_tokens": total_input_tokens,
|
||||||
}
|
}
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
|
|
||||||
|
def ollama_embeddings(
|
||||||
|
api_base: str,
|
||||||
|
model: str,
|
||||||
|
prompts: list,
|
||||||
|
optional_params=None,
|
||||||
|
logging_obj=None,
|
||||||
|
model_response=None,
|
||||||
|
encoding=None,
|
||||||
|
):
|
||||||
|
return asyncio.run(
|
||||||
|
ollama_aembeddings(
|
||||||
|
api_base,
|
||||||
|
model,
|
||||||
|
prompts,
|
||||||
|
optional_params,
|
||||||
|
logging_obj,
|
||||||
|
model_response,
|
||||||
|
encoding,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from itertools import chain
|
||||||
import requests, types, time
|
import requests, types, time
|
||||||
import json, uuid
|
import json, uuid
|
||||||
import traceback
|
import traceback
|
||||||
|
@ -297,6 +298,7 @@ def get_ollama_response(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
model_response["choices"][0]["message"] = message
|
model_response["choices"][0]["message"] = message
|
||||||
|
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||||
else:
|
else:
|
||||||
model_response["choices"][0]["message"] = response_json["message"]
|
model_response["choices"][0]["message"] = response_json["message"]
|
||||||
model_response["created"] = int(time.time())
|
model_response["created"] = int(time.time())
|
||||||
|
@ -335,8 +337,35 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
|
||||||
custom_llm_provider="ollama_chat",
|
custom_llm_provider="ollama_chat",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
)
|
)
|
||||||
for transformed_chunk in streamwrapper:
|
|
||||||
yield transformed_chunk
|
# If format is JSON, this was a function call
|
||||||
|
# Gather all chunks and return the function call as one delta to simplify parsing
|
||||||
|
if data.get("format", "") == "json":
|
||||||
|
first_chunk = next(streamwrapper)
|
||||||
|
response_content = "".join(
|
||||||
|
chunk.choices[0].delta.content
|
||||||
|
for chunk in chain([first_chunk], streamwrapper)
|
||||||
|
if chunk.choices[0].delta.content
|
||||||
|
)
|
||||||
|
|
||||||
|
function_call = json.loads(response_content)
|
||||||
|
delta = litellm.utils.Delta(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"id": f"call_{str(uuid.uuid4())}",
|
||||||
|
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
model_response = first_chunk
|
||||||
|
model_response["choices"][0]["delta"] = delta
|
||||||
|
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||||
|
yield model_response
|
||||||
|
else:
|
||||||
|
for transformed_chunk in streamwrapper:
|
||||||
|
yield transformed_chunk
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
@ -366,8 +395,36 @@ async def ollama_async_streaming(
|
||||||
custom_llm_provider="ollama_chat",
|
custom_llm_provider="ollama_chat",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
)
|
)
|
||||||
async for transformed_chunk in streamwrapper:
|
|
||||||
yield transformed_chunk
|
# If format is JSON, this was a function call
|
||||||
|
# Gather all chunks and return the function call as one delta to simplify parsing
|
||||||
|
if data.get("format", "") == "json":
|
||||||
|
first_chunk = await anext(streamwrapper)
|
||||||
|
first_chunk_content = first_chunk.choices[0].delta.content or ""
|
||||||
|
response_content = first_chunk_content + "".join(
|
||||||
|
[
|
||||||
|
chunk.choices[0].delta.content
|
||||||
|
async for chunk in streamwrapper
|
||||||
|
if chunk.choices[0].delta.content]
|
||||||
|
)
|
||||||
|
function_call = json.loads(response_content)
|
||||||
|
delta = litellm.utils.Delta(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"id": f"call_{str(uuid.uuid4())}",
|
||||||
|
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
model_response = first_chunk
|
||||||
|
model_response["choices"][0]["delta"] = delta
|
||||||
|
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||||
|
yield model_response
|
||||||
|
else:
|
||||||
|
async for transformed_chunk in streamwrapper:
|
||||||
|
yield transformed_chunk
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
@ -425,6 +482,7 @@ async def ollama_acompletion(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
model_response["choices"][0]["message"] = message
|
model_response["choices"][0]["message"] = message
|
||||||
|
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||||
else:
|
else:
|
||||||
model_response["choices"][0]["message"] = response_json["message"]
|
model_response["choices"][0]["message"] = response_json["message"]
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
from litellm.utils import ModelResponse, Usage
|
from litellm.utils import ModelResponse, Usage
|
||||||
|
|
|
@ -1,4 +1,13 @@
|
||||||
from typing import Optional, Union, Any, BinaryIO
|
from typing import (
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
Any,
|
||||||
|
BinaryIO,
|
||||||
|
Literal,
|
||||||
|
Iterable,
|
||||||
|
)
|
||||||
|
from typing_extensions import override
|
||||||
|
from pydantic import BaseModel
|
||||||
import types, time, json, traceback
|
import types, time, json, traceback
|
||||||
import httpx
|
import httpx
|
||||||
from .base import BaseLLM
|
from .base import BaseLLM
|
||||||
|
@ -13,10 +22,10 @@ from litellm.utils import (
|
||||||
TextCompletionResponse,
|
TextCompletionResponse,
|
||||||
)
|
)
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
import aiohttp, requests
|
|
||||||
import litellm
|
import litellm
|
||||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
from openai import OpenAI, AsyncOpenAI
|
from openai import OpenAI, AsyncOpenAI
|
||||||
|
from ..types.llms.openai import *
|
||||||
|
|
||||||
|
|
||||||
class OpenAIError(Exception):
|
class OpenAIError(Exception):
|
||||||
|
@ -246,7 +255,7 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
def completion(
|
def completion(
|
||||||
self,
|
self,
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
timeout: float,
|
timeout: Union[float, httpx.Timeout],
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
messages: Optional[list] = None,
|
messages: Optional[list] = None,
|
||||||
print_verbose: Optional[Callable] = None,
|
print_verbose: Optional[Callable] = None,
|
||||||
|
@ -271,9 +280,12 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
if model is None or messages is None:
|
if model is None or messages is None:
|
||||||
raise OpenAIError(status_code=422, message=f"Missing model or messages")
|
raise OpenAIError(status_code=422, message=f"Missing model or messages")
|
||||||
|
|
||||||
if not isinstance(timeout, float):
|
if not isinstance(timeout, float) and not isinstance(
|
||||||
|
timeout, httpx.Timeout
|
||||||
|
):
|
||||||
raise OpenAIError(
|
raise OpenAIError(
|
||||||
status_code=422, message=f"Timeout needs to be a float"
|
status_code=422,
|
||||||
|
message=f"Timeout needs to be a float or httpx.Timeout",
|
||||||
)
|
)
|
||||||
|
|
||||||
if custom_llm_provider != "openai":
|
if custom_llm_provider != "openai":
|
||||||
|
@ -425,7 +437,7 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
self,
|
self,
|
||||||
data: dict,
|
data: dict,
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
timeout: float,
|
timeout: Union[float, httpx.Timeout],
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
organization: Optional[str] = None,
|
organization: Optional[str] = None,
|
||||||
|
@ -480,7 +492,7 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
def streaming(
|
def streaming(
|
||||||
self,
|
self,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
timeout: float,
|
timeout: Union[float, httpx.Timeout],
|
||||||
data: dict,
|
data: dict,
|
||||||
model: str,
|
model: str,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
@ -518,13 +530,14 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
model=model,
|
model=model,
|
||||||
custom_llm_provider="openai",
|
custom_llm_provider="openai",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
|
stream_options=data.get("stream_options", None),
|
||||||
)
|
)
|
||||||
return streamwrapper
|
return streamwrapper
|
||||||
|
|
||||||
async def async_streaming(
|
async def async_streaming(
|
||||||
self,
|
self,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
timeout: float,
|
timeout: Union[float, httpx.Timeout],
|
||||||
data: dict,
|
data: dict,
|
||||||
model: str,
|
model: str,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
@ -567,6 +580,7 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
model=model,
|
model=model,
|
||||||
custom_llm_provider="openai",
|
custom_llm_provider="openai",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
|
stream_options=data.get("stream_options", None),
|
||||||
)
|
)
|
||||||
return streamwrapper
|
return streamwrapper
|
||||||
except (
|
except (
|
||||||
|
@ -1191,6 +1205,7 @@ class OpenAITextCompletion(BaseLLM):
|
||||||
model=model,
|
model=model,
|
||||||
custom_llm_provider="text-completion-openai",
|
custom_llm_provider="text-completion-openai",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
|
stream_options=data.get("stream_options", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
for chunk in streamwrapper:
|
for chunk in streamwrapper:
|
||||||
|
@ -1229,7 +1244,228 @@ class OpenAITextCompletion(BaseLLM):
|
||||||
model=model,
|
model=model,
|
||||||
custom_llm_provider="text-completion-openai",
|
custom_llm_provider="text-completion-openai",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
|
stream_options=data.get("stream_options", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
async for transformed_chunk in streamwrapper:
|
async for transformed_chunk in streamwrapper:
|
||||||
yield transformed_chunk
|
yield transformed_chunk
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIAssistantsAPI(BaseLLM):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def get_openai_client(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
organization: Optional[str],
|
||||||
|
client: Optional[OpenAI] = None,
|
||||||
|
) -> OpenAI:
|
||||||
|
received_args = locals()
|
||||||
|
if client is None:
|
||||||
|
data = {}
|
||||||
|
for k, v in received_args.items():
|
||||||
|
if k == "self" or k == "client":
|
||||||
|
pass
|
||||||
|
elif k == "api_base" and v is not None:
|
||||||
|
data["base_url"] = v
|
||||||
|
elif v is not None:
|
||||||
|
data[k] = v
|
||||||
|
openai_client = OpenAI(**data) # type: ignore
|
||||||
|
else:
|
||||||
|
openai_client = client
|
||||||
|
|
||||||
|
return openai_client
|
||||||
|
|
||||||
|
### ASSISTANTS ###
|
||||||
|
|
||||||
|
def get_assistants(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
organization: Optional[str],
|
||||||
|
client: Optional[OpenAI],
|
||||||
|
) -> SyncCursorPage[Assistant]:
|
||||||
|
openai_client = self.get_openai_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
organization=organization,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = openai_client.beta.assistants.list()
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
### MESSAGES ###
|
||||||
|
|
||||||
|
def add_message(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
message_data: MessageData,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
organization: Optional[str],
|
||||||
|
client: Optional[OpenAI] = None,
|
||||||
|
) -> OpenAIMessage:
|
||||||
|
|
||||||
|
openai_client = self.get_openai_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
organization=organization,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
thread_message: OpenAIMessage = openai_client.beta.threads.messages.create(
|
||||||
|
thread_id, **message_data
|
||||||
|
)
|
||||||
|
|
||||||
|
response_obj: Optional[OpenAIMessage] = None
|
||||||
|
if getattr(thread_message, "status", None) is None:
|
||||||
|
thread_message.status = "completed"
|
||||||
|
response_obj = OpenAIMessage(**thread_message.dict())
|
||||||
|
else:
|
||||||
|
response_obj = OpenAIMessage(**thread_message.dict())
|
||||||
|
return response_obj
|
||||||
|
|
||||||
|
def get_messages(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
organization: Optional[str],
|
||||||
|
client: Optional[OpenAI] = None,
|
||||||
|
) -> SyncCursorPage[OpenAIMessage]:
|
||||||
|
openai_client = self.get_openai_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
organization=organization,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = openai_client.beta.threads.messages.list(thread_id=thread_id)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
### THREADS ###
|
||||||
|
|
||||||
|
def create_thread(
|
||||||
|
self,
|
||||||
|
metadata: Optional[dict],
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
organization: Optional[str],
|
||||||
|
client: Optional[OpenAI],
|
||||||
|
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
||||||
|
) -> Thread:
|
||||||
|
"""
|
||||||
|
Here's an example:
|
||||||
|
```
|
||||||
|
from litellm.llms.openai import OpenAIAssistantsAPI, MessageData
|
||||||
|
|
||||||
|
# create thread
|
||||||
|
message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
|
||||||
|
openai_api.create_thread(messages=[message])
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
openai_client = self.get_openai_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
organization=organization,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = {}
|
||||||
|
if messages is not None:
|
||||||
|
data["messages"] = messages # type: ignore
|
||||||
|
if metadata is not None:
|
||||||
|
data["metadata"] = metadata # type: ignore
|
||||||
|
|
||||||
|
message_thread = openai_client.beta.threads.create(**data) # type: ignore
|
||||||
|
|
||||||
|
return Thread(**message_thread.dict())
|
||||||
|
|
||||||
|
def get_thread(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
organization: Optional[str],
|
||||||
|
client: Optional[OpenAI],
|
||||||
|
) -> Thread:
|
||||||
|
openai_client = self.get_openai_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
organization=organization,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = openai_client.beta.threads.retrieve(thread_id=thread_id)
|
||||||
|
|
||||||
|
return Thread(**response.dict())
|
||||||
|
|
||||||
|
def delete_thread(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
### RUNS ###
|
||||||
|
|
||||||
|
def run_thread(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
assistant_id: str,
|
||||||
|
additional_instructions: Optional[str],
|
||||||
|
instructions: Optional[str],
|
||||||
|
metadata: Optional[object],
|
||||||
|
model: Optional[str],
|
||||||
|
stream: Optional[bool],
|
||||||
|
tools: Optional[Iterable[AssistantToolParam]],
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
organization: Optional[str],
|
||||||
|
client: Optional[OpenAI],
|
||||||
|
) -> Run:
|
||||||
|
openai_client = self.get_openai_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
organization=organization,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = openai_client.beta.threads.runs.create_and_poll(
|
||||||
|
thread_id=thread_id,
|
||||||
|
assistant_id=assistant_id,
|
||||||
|
additional_instructions=additional_instructions,
|
||||||
|
instructions=instructions,
|
||||||
|
metadata=metadata,
|
||||||
|
model=model,
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import os, types
|
import os, types
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests
|
import requests # type: ignore
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
import litellm
|
import litellm
|
||||||
|
|
520
litellm/llms/predibase.py
Normal file
520
litellm/llms/predibase.py
Normal file
|
@ -0,0 +1,520 @@
|
||||||
|
# What is this?
|
||||||
|
## Controller file for Predibase Integration - https://predibase.com/
|
||||||
|
|
||||||
|
|
||||||
|
import os, types
|
||||||
|
import json
|
||||||
|
from enum import Enum
|
||||||
|
import requests, copy # type: ignore
|
||||||
|
import time
|
||||||
|
from typing import Callable, Optional, List, Literal, Union
|
||||||
|
from litellm.utils import (
|
||||||
|
ModelResponse,
|
||||||
|
Usage,
|
||||||
|
map_finish_reason,
|
||||||
|
CustomStreamWrapper,
|
||||||
|
Message,
|
||||||
|
Choices,
|
||||||
|
)
|
||||||
|
import litellm
|
||||||
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
|
from .base import BaseLLM
|
||||||
|
import httpx # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class PredibaseError(Exception):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
status_code,
|
||||||
|
message,
|
||||||
|
request: Optional[httpx.Request] = None,
|
||||||
|
response: Optional[httpx.Response] = None,
|
||||||
|
):
|
||||||
|
self.status_code = status_code
|
||||||
|
self.message = message
|
||||||
|
if request is not None:
|
||||||
|
self.request = request
|
||||||
|
else:
|
||||||
|
self.request = httpx.Request(
|
||||||
|
method="POST",
|
||||||
|
url="https://docs.predibase.com/user-guide/inference/rest_api",
|
||||||
|
)
|
||||||
|
if response is not None:
|
||||||
|
self.response = response
|
||||||
|
else:
|
||||||
|
self.response = httpx.Response(
|
||||||
|
status_code=status_code, request=self.request
|
||||||
|
)
|
||||||
|
super().__init__(
|
||||||
|
self.message
|
||||||
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
|
class PredibaseConfig:
|
||||||
|
"""
|
||||||
|
Reference: https://docs.predibase.com/user-guide/inference/rest_api
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
adapter_id: Optional[str] = None
|
||||||
|
adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
|
||||||
|
best_of: Optional[int] = None
|
||||||
|
decoder_input_details: Optional[bool] = None
|
||||||
|
details: bool = True # enables returning logprobs + best of
|
||||||
|
max_new_tokens: int = (
|
||||||
|
256 # openai default - requests hang if max_new_tokens not given
|
||||||
|
)
|
||||||
|
repetition_penalty: Optional[float] = None
|
||||||
|
return_full_text: Optional[bool] = (
|
||||||
|
False # by default don't return the input as part of the output
|
||||||
|
)
|
||||||
|
seed: Optional[int] = None
|
||||||
|
stop: Optional[List[str]] = None
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
top_k: Optional[int] = None
|
||||||
|
top_p: Optional[int] = None
|
||||||
|
truncate: Optional[int] = None
|
||||||
|
typical_p: Optional[float] = None
|
||||||
|
watermark: Optional[bool] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
best_of: Optional[int] = None,
|
||||||
|
decoder_input_details: Optional[bool] = None,
|
||||||
|
details: Optional[bool] = None,
|
||||||
|
max_new_tokens: Optional[int] = None,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
return_full_text: Optional[bool] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
top_p: Optional[int] = None,
|
||||||
|
truncate: Optional[int] = None,
|
||||||
|
typical_p: Optional[float] = None,
|
||||||
|
watermark: Optional[bool] = None,
|
||||||
|
) -> None:
|
||||||
|
locals_ = locals()
|
||||||
|
for key, value in locals_.items():
|
||||||
|
if key != "self" and value is not None:
|
||||||
|
setattr(self.__class__, key, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls):
|
||||||
|
return {
|
||||||
|
k: v
|
||||||
|
for k, v in cls.__dict__.items()
|
||||||
|
if not k.startswith("__")
|
||||||
|
and not isinstance(
|
||||||
|
v,
|
||||||
|
(
|
||||||
|
types.FunctionType,
|
||||||
|
types.BuiltinFunctionType,
|
||||||
|
classmethod,
|
||||||
|
staticmethod,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
and v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_supported_openai_params(self):
|
||||||
|
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
|
||||||
|
|
||||||
|
|
||||||
|
class PredibaseChatCompletion(BaseLLM):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict:
|
||||||
|
if api_key is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
|
||||||
|
)
|
||||||
|
headers = {
|
||||||
|
"content-type": "application/json",
|
||||||
|
"Authorization": "Bearer {}".format(api_key),
|
||||||
|
}
|
||||||
|
if user_headers is not None and isinstance(user_headers, dict):
|
||||||
|
headers = {**headers, **user_headers}
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def output_parser(self, generated_text: str):
|
||||||
|
"""
|
||||||
|
Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
|
||||||
|
|
||||||
|
Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
|
||||||
|
"""
|
||||||
|
chat_template_tokens = [
|
||||||
|
"<|assistant|>",
|
||||||
|
"<|system|>",
|
||||||
|
"<|user|>",
|
||||||
|
"<s>",
|
||||||
|
"</s>",
|
||||||
|
]
|
||||||
|
for token in chat_template_tokens:
|
||||||
|
if generated_text.strip().startswith(token):
|
||||||
|
generated_text = generated_text.replace(token, "", 1)
|
||||||
|
if generated_text.endswith(token):
|
||||||
|
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
|
||||||
|
return generated_text
|
||||||
|
|
||||||
|
def process_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
response: Union[requests.Response, httpx.Response],
|
||||||
|
model_response: ModelResponse,
|
||||||
|
stream: bool,
|
||||||
|
logging_obj: litellm.utils.Logging,
|
||||||
|
optional_params: dict,
|
||||||
|
api_key: str,
|
||||||
|
data: dict,
|
||||||
|
messages: list,
|
||||||
|
print_verbose,
|
||||||
|
encoding,
|
||||||
|
) -> ModelResponse:
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=response.text,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
print_verbose(f"raw model_response: {response.text}")
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
try:
|
||||||
|
completion_response = response.json()
|
||||||
|
except:
|
||||||
|
raise PredibaseError(
|
||||||
|
message=response.text, status_code=response.status_code
|
||||||
|
)
|
||||||
|
if "error" in completion_response:
|
||||||
|
raise PredibaseError(
|
||||||
|
message=str(completion_response["error"]),
|
||||||
|
status_code=response.status_code,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if (
|
||||||
|
not isinstance(completion_response, dict)
|
||||||
|
or "generated_text" not in completion_response
|
||||||
|
):
|
||||||
|
raise PredibaseError(
|
||||||
|
status_code=422,
|
||||||
|
message=f"response is not in expected format - {completion_response}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(completion_response["generated_text"]) > 0:
|
||||||
|
model_response["choices"][0]["message"]["content"] = self.output_parser(
|
||||||
|
completion_response["generated_text"]
|
||||||
|
)
|
||||||
|
## GETTING LOGPROBS + FINISH REASON
|
||||||
|
if (
|
||||||
|
"details" in completion_response
|
||||||
|
and "tokens" in completion_response["details"]
|
||||||
|
):
|
||||||
|
model_response.choices[0].finish_reason = completion_response[
|
||||||
|
"details"
|
||||||
|
]["finish_reason"]
|
||||||
|
sum_logprob = 0
|
||||||
|
for token in completion_response["details"]["tokens"]:
|
||||||
|
if token["logprob"] != None:
|
||||||
|
sum_logprob += token["logprob"]
|
||||||
|
model_response["choices"][0][
|
||||||
|
"message"
|
||||||
|
]._logprob = (
|
||||||
|
                    sum_logprob  # [TODO] move this to using the actual logprobs
                )
        if "best_of" in optional_params and optional_params["best_of"] > 1:
            if (
                "details" in completion_response
                and "best_of_sequences" in completion_response["details"]
            ):
                choices_list = []
                for idx, item in enumerate(
                    completion_response["details"]["best_of_sequences"]
                ):
                    sum_logprob = 0
                    for token in item["tokens"]:
                        if token["logprob"] != None:
                            sum_logprob += token["logprob"]
                    if len(item["generated_text"]) > 0:
                        message_obj = Message(
                            content=self.output_parser(item["generated_text"]),
                            logprobs=sum_logprob,
                        )
                    else:
                        message_obj = Message(content=None)
                    choice_obj = Choices(
                        finish_reason=item["finish_reason"],
                        index=idx + 1,
                        message=message_obj,
                    )
                    choices_list.append(choice_obj)
                model_response["choices"].extend(choices_list)

        ## CALCULATING USAGE
        prompt_tokens = 0
        try:
            prompt_tokens = len(
                encoding.encode(model_response["choices"][0]["message"]["content"])
            )  ##[TODO] use a model-specific tokenizer here
        except:
            # this should remain non blocking we should not block a response returning if calculating usage fails
            pass
        output_text = model_response["choices"][0]["message"].get("content", "")
        if output_text is not None and len(output_text) > 0:
            completion_tokens = 0
            try:
                completion_tokens = len(
                    encoding.encode(
                        model_response["choices"][0]["message"].get("content", "")
                    )
                )  ##[TODO] use a model-specific tokenizer
            except:
                # this should remain non blocking we should not block a response returning if calculating usage fails
                pass
        else:
            completion_tokens = 0

        total_tokens = prompt_tokens + completion_tokens

        model_response["created"] = int(time.time())
        model_response["model"] = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
        )
        model_response.usage = usage  # type: ignore
        return model_response

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key: str,
        logging_obj,
        optional_params: dict,
        tenant_id: str,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers: dict = {},
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        headers = self._validate_environment(api_key, headers)
        completion_url = ""
        input_text = ""
        base_url = "https://serving.app.predibase.com"
        if "https" in model:
            completion_url = model
        elif api_base:
            base_url = api_base
        elif "PREDIBASE_API_BASE" in os.environ:
            base_url = os.getenv("PREDIBASE_API_BASE", "")

        completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"

        if optional_params.get("stream", False) == True:
            completion_url += "/generate_stream"
        else:
            completion_url += "/generate"

        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details["roles"],
                initial_prompt_value=model_prompt_details["initial_prompt_value"],
                final_prompt_value=model_prompt_details["final_prompt_value"],
                messages=messages,
            )
        else:
            prompt = prompt_factory(model=model, messages=messages)

        ## Load Config
        config = litellm.PredibaseConfig.get_config()
        for k, v in config.items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                optional_params[k] = v

        stream = optional_params.pop("stream", False)

        data = {
            "inputs": prompt,
            "parameters": optional_params,
        }
        input_text = prompt
        ## LOGGING
        logging_obj.pre_call(
            input=input_text,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "headers": headers,
                "api_base": completion_url,
                "acompletion": acompletion,
            },
        )
        ## COMPLETION CALL
        if acompletion is True:
            ### ASYNC STREAMING
            if stream == True:
                return self.async_streaming(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=completion_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                )  # type: ignore
            else:
                ### ASYNC COMPLETION
                return self.async_completion(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=completion_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=False,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                )  # type: ignore

        ### SYNC STREAMING
        if stream == True:
            response = requests.post(
                completion_url,
                headers=headers,
                data=json.dumps(data),
                stream=stream,
            )
            _response = CustomStreamWrapper(
                response.iter_lines(),
                model,
                custom_llm_provider="predibase",
                logging_obj=logging_obj,
            )
            return _response
        ### SYNC COMPLETION
        else:
            response = requests.post(
                url=completion_url,
                headers=headers,
                data=json.dumps(data),
            )

        return self.process_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=optional_params.get("stream", False),
            logging_obj=logging_obj,  # type: ignore
            optional_params=optional_params,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            encoding=encoding,
        )

    async def async_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        stream,
        data: dict,
        optional_params: dict,
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> ModelResponse:
        self.async_handler = AsyncHTTPHandler(
            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
        )
        response = await self.async_handler.post(
            api_base, headers=headers, data=json.dumps(data)
        )
        return self.process_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream,
            logging_obj=logging_obj,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            optional_params=optional_params,
            encoding=encoding,
        )

    async def async_streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        data: dict,
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> CustomStreamWrapper:
        self.async_handler = AsyncHTTPHandler(
            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
        )
        data["stream"] = True
        response = await self.async_handler.post(
            url=api_base,
            headers=headers,
            data=json.dumps(data),
            stream=True,
        )

        if response.status_code != 200:
            raise PredibaseError(
                status_code=response.status_code, message=response.text
            )

        completion_stream = response.aiter_lines()

        streamwrapper = CustomStreamWrapper(
            completion_stream=completion_stream,
            model=model,
            custom_llm_provider="predibase",
            logging_obj=logging_obj,
        )
        return streamwrapper

    def embedding(self, *args, **kwargs):
        pass
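A minimal usage sketch for the handler above (not part of the diff; the model name and tenant id are placeholders, and it assumes the `predibase/` model prefix routes to this provider via the `litellm/main.py` wiring further down in this diff):

import os
import litellm

os.environ["PREDIBASE_API_KEY"] = "pb_..."        # placeholder
os.environ["PREDIBASE_TENANT_ID"] = "my-tenant"   # placeholder

# routes to PredibaseChatCompletion.completion() above
response = litellm.completion(
    model="predibase/llama-3-8b-instruct",        # placeholder deployment name
    messages=[{"role": "user", "content": "Hello, who are you?"}],
    max_tokens=64,
)
print(response.choices[0].message.content)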
@@ -12,6 +12,16 @@ from typing import (
     Sequence,
 )
 import litellm
+from litellm.types.completion import (
+    ChatCompletionUserMessageParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionMessageParam,
+    ChatCompletionFunctionMessageParam,
+    ChatCompletionMessageToolCallParam,
+    ChatCompletionToolMessageParam,
+)
+from litellm.types.llms.anthropic import *
+import uuid


 def default_pt(messages):
@@ -22,6 +32,41 @@ def prompt_injection_detection_default_pt():
     return """Detect if a prompt is safe to run. Return 'UNSAFE' if not."""


+def map_system_message_pt(messages: list) -> list:
+    """
+    Convert 'system' message to 'user' message if provider doesn't support 'system' role.
+
+    Enabled via `completion(...,supports_system_message=False)`
+
+    If next message is a user message or assistant message -> merge system prompt into it
+
+    if next message is system -> append a user message instead of the system message
+    """
+
+    new_messages = []
+    for i, m in enumerate(messages):
+        if m["role"] == "system":
+            if i < len(messages) - 1:  # Not the last message
+                next_m = messages[i + 1]
+                next_role = next_m["role"]
+                if (
+                    next_role == "user" or next_role == "assistant"
+                ):  # Next message is a user or assistant message
+                    # Merge system prompt into the next message
+                    next_m["content"] = m["content"] + " " + next_m["content"]
+                elif next_role == "system":  # Next message is a system message
+                    # Append a user message instead of the system message
+                    new_message = {"role": "user", "content": m["content"]}
+                    new_messages.append(new_message)
+            else:  # Last message
+                new_message = {"role": "user", "content": m["content"]}
+                new_messages.append(new_message)
+        else:  # Not a system message
+            new_messages.append(m)
+
+    return new_messages
+
+
 # alpaca prompt template - for models like mythomax, etc.
 def alpaca_pt(messages):
     prompt = custom_prompt(
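A quick illustration of the merge behaviour described in the docstring (not part of the diff; the module path is assumed):

from litellm.llms.prompt_templates.factory import map_system_message_pt  # assumed path

msgs = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi"},
]
print(map_system_message_pt(messages=msgs))
# -> [{"role": "user", "content": "You are terse. Hi"}]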
@@ -805,6 +850,13 @@ def convert_to_anthropic_tool_result(message: dict) -> dict:
         "name": "get_current_weather",
         "content": "function result goes here",
     },
+
+    OpenAI message with a function call result looks like:
+    {
+        "role": "function",
+        "name": "get_current_weather",
+        "content": "function result goes here",
+    }
     """

     """
@@ -821,18 +873,42 @@ def convert_to_anthropic_tool_result(message: dict) -> dict:
         ]
     }
     """
-    tool_call_id = message.get("tool_call_id")
-    content = message.get("content")
+    if message["role"] == "tool":
+        tool_call_id = message.get("tool_call_id")
+        content = message.get("content")

-    # We can't determine from openai message format whether it's a successful or
-    # error call result so default to the successful result template
-    anthropic_tool_result = {
-        "type": "tool_result",
-        "tool_use_id": tool_call_id,
-        "content": content,
-    }
+        # We can't determine from openai message format whether it's a successful or
+        # error call result so default to the successful result template
+        anthropic_tool_result = {
+            "type": "tool_result",
+            "tool_use_id": tool_call_id,
+            "content": content,
+        }
+        return anthropic_tool_result

-    return anthropic_tool_result
+    elif message["role"] == "function":
+        content = message.get("content")
+        anthropic_tool_result = {
+            "type": "tool_result",
+            "tool_use_id": str(uuid.uuid4()),
+            "content": content,
+        }
+        return anthropic_tool_result
+    return {}
+
+
+def convert_function_to_anthropic_tool_invoke(function_call):
+    try:
+        anthropic_tool_invoke = [
+            {
+                "type": "tool_use",
+                "id": str(uuid.uuid4()),
+                "name": get_attribute_or_key(function_call, "name"),
+                "input": json.loads(get_attribute_or_key(function_call, "arguments")),
+            }
+        ]
+        return anthropic_tool_invoke
+    except Exception as e:
+        raise e


 def convert_to_anthropic_tool_invoke(tool_calls: list) -> list:
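Illustrative calls against the two converters above (not part of the diff; the module path is assumed and the uuid values are generated per call):

from litellm.llms.prompt_templates.factory import (  # assumed path
    convert_to_anthropic_tool_result,
    convert_function_to_anthropic_tool_invoke,
)

convert_to_anthropic_tool_result(
    {"role": "function", "name": "get_current_weather", "content": "72F and sunny"}
)
# -> {"type": "tool_result", "tool_use_id": "<generated uuid>", "content": "72F and sunny"}

convert_function_to_anthropic_tool_invoke(
    {"name": "get_current_weather", "arguments": '{"location": "Boston"}'}
)
# -> [{"type": "tool_use", "id": "<generated uuid>", "name": "get_current_weather",
#      "input": {"location": "Boston"}}]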
@@ -895,7 +971,7 @@ def convert_to_anthropic_tool_invoke(tool_calls: list) -> list:
 def anthropic_messages_pt(messages: list):
     """
     format messages for anthropic
-    1. Anthropic supports roles like "user" and "assistant", (here litellm translates system-> assistant)
+    1. Anthropic supports roles like "user" and "assistant" (system prompt sent separately)
     2. The first message always needs to be of role "user"
     3. Each message must alternate between "user" and "assistant" (this is not addressed as now by litellm)
     4. final assistant content cannot end with trailing whitespace (anthropic raises an error otherwise)
@@ -903,12 +979,14 @@ def anthropic_messages_pt(messages: list):
     6. Ensure we only accept role, content. (message.name is not supported)
     """
     # add role=tool support to allow function call result/error submission
-    user_message_types = {"user", "tool"}
+    user_message_types = {"user", "tool", "function"}
     # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, merge them.
-    new_messages = []
+    new_messages: list = []
     msg_i = 0
+    tool_use_param = False
     while msg_i < len(messages):
         user_content = []
+        init_msg_i = msg_i
         ## MERGE CONSECUTIVE USER CONTENT ##
         while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
             if isinstance(messages[msg_i]["content"], list):
@@ -924,7 +1002,10 @@ def anthropic_messages_pt(messages: list):
                     )
                 elif m.get("type", "") == "text":
                     user_content.append({"type": "text", "text": m["text"]})
-            elif messages[msg_i]["role"] == "tool":
+            elif (
+                messages[msg_i]["role"] == "tool"
+                or messages[msg_i]["role"] == "function"
+            ):
                 # OpenAI's tool message content will always be a string
                 user_content.append(convert_to_anthropic_tool_result(messages[msg_i]))
             else:
@@ -953,11 +1034,24 @@ def anthropic_messages_pt(messages: list):
                     convert_to_anthropic_tool_invoke(messages[msg_i]["tool_calls"])
                 )
+
+            if messages[msg_i].get("function_call"):
+                assistant_content.extend(
+                    convert_function_to_anthropic_tool_invoke(
+                        messages[msg_i]["function_call"]
+                    )
+                )
+
             msg_i += 1

         if assistant_content:
             new_messages.append({"role": "assistant", "content": assistant_content})
+
+        if msg_i == init_msg_i:  # prevent infinite loops
+            raise Exception(
+                "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
+                    messages[msg_i]
+                )
+            )
     if not new_messages or new_messages[0]["role"] != "user":
         if litellm.modify_params:
             new_messages.insert(
@@ -969,11 +1063,14 @@ def anthropic_messages_pt(messages: list):
             )

     if new_messages[-1]["role"] == "assistant":
-        for content in new_messages[-1]["content"]:
-            if isinstance(content, dict) and content["type"] == "text":
-                content["text"] = content[
-                    "text"
-                ].rstrip()  # no trailing whitespace for final assistant message
+        if isinstance(new_messages[-1]["content"], str):
+            new_messages[-1]["content"] = new_messages[-1]["content"].rstrip()
+        elif isinstance(new_messages[-1]["content"], list):
+            for content in new_messages[-1]["content"]:
+                if isinstance(content, dict) and content["type"] == "text":
+                    content["text"] = content[
+                        "text"
+                    ].rstrip()  # no trailing whitespace for final assistant message

     return new_messages
@@ -1,11 +1,11 @@
 import os, types
 import json
-import requests
+import requests  # type: ignore
 import time
 from typing import Callable, Optional
 from litellm.utils import ModelResponse, Usage
 import litellm
-import httpx
+import httpx  # type: ignore
 from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -1,14 +1,14 @@
 import os, types, traceback
 from enum import Enum
 import json
-import requests
+import requests  # type: ignore
 import time
 from typing import Callable, Optional, Any
 import litellm
 from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage
 import sys
 from copy import deepcopy
-import httpx
+import httpx  # type: ignore
 from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -295,7 +295,7 @@ def completion(
             EndpointName={model},
             InferenceComponentName={model_id},
             ContentType="application/json",
-            Body={data},
+            Body={data},  # type: ignore
             CustomAttributes="accept_eula=true",
         )
     """  # type: ignore
@@ -321,7 +321,7 @@ def completion(
         response = client.invoke_endpoint(
             EndpointName={model},
             ContentType="application/json",
-            Body={data},
+            Body={data},  # type: ignore
             CustomAttributes="accept_eula=true",
         )
     """  # type: ignore
@@ -688,7 +688,7 @@ def embedding(
     response = client.invoke_endpoint(
         EndpointName={model},
         ContentType="application/json",
-        Body={data},
+        Body={data},  # type: ignore
         CustomAttributes="accept_eula=true",
     )"""  # type: ignore
     logging_obj.pre_call(
@@ -6,11 +6,11 @@ Reference: https://docs.together.ai/docs/openai-api-compatibility
 import os, types
 import json
 from enum import Enum
-import requests
+import requests  # type: ignore
 import time
 from typing import Callable, Optional
 import litellm
-import httpx
+import httpx  # type: ignore
 from litellm.utils import ModelResponse, Usage
 from .prompt_templates.factory import prompt_factory, custom_prompt
litellm/llms/triton.py  (new file, +119 lines)

import os, types
import json
from enum import Enum
import requests, copy  # type: ignore
import time
from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx  # type: ignore


class TritonError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST",
            url="https://api.anthropic.com/v1/messages",  # using anthropic api base since httpx requires a url
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class TritonChatCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    async def aembedding(
        self,
        data: dict,
        model_response: litellm.utils.EmbeddingResponse,
        api_base: str,
        logging_obj=None,
        api_key: Optional[str] = None,
    ):

        async_handler = AsyncHTTPHandler(
            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
        )

        response = await async_handler.post(url=api_base, data=json.dumps(data))

        if response.status_code != 200:
            raise TritonError(status_code=response.status_code, message=response.text)

        _text_response = response.text

        logging_obj.post_call(original_response=_text_response)

        _json_response = response.json()

        _outputs = _json_response["outputs"]
        _output_data = _outputs[0]["data"]
        _embedding_output = {
            "object": "embedding",
            "index": 0,
            "embedding": _output_data,
        }

        model_response.model = _json_response.get("model_name", "None")
        model_response.data = [_embedding_output]

        return model_response

    def embedding(
        self,
        model: str,
        input: list,
        timeout: float,
        api_base: str,
        model_response: litellm.utils.EmbeddingResponse,
        api_key: Optional[str] = None,
        logging_obj=None,
        optional_params=None,
        client=None,
        aembedding=None,
    ):
        data_for_triton = {
            "inputs": [
                {
                    "name": "input_text",
                    "shape": [1],
                    "datatype": "BYTES",
                    "data": input,
                }
            ]
        }

        ## LOGGING

        curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"

        logging_obj.pre_call(
            input="",
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": curl_string,
            },
        )

        if aembedding == True:
            response = self.aembedding(
                data=data_for_triton,
                model_response=model_response,
                logging_obj=logging_obj,
                api_base=api_base,
                api_key=api_key,
            )
            return response
        else:
            raise Exception(
                "Only async embedding supported for triton, please use litellm.aembedding() for now"
            )
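A usage sketch for the new Triton embedding path (not part of the diff; the model name and inference URL are placeholders, and it assumes the `triton/` prefix plus an explicit `api_base`, which the `litellm/main.py` wiring later in this diff requires):

import asyncio
import litellm

async def main():
    resp = await litellm.aembedding(
        model="triton/my-embedding-model",  # placeholder model name
        api_base="http://localhost:8000/v2/models/my-embedding-model/infer",  # placeholder URL
        input=["hello world"],
    )
    # the handler returns one dict per input with an "embedding" key
    print(resp.data[0]["embedding"][:5])

asyncio.run(main())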
@@ -1,12 +1,12 @@
 import os, types
 import json
 from enum import Enum
-import requests
+import requests  # type: ignore
 import time
 from typing import Callable, Optional, Union, List
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
 import litellm, uuid
-import httpx, inspect
+import httpx, inspect  # type: ignore


 class VertexAIError(Exception):
@@ -419,6 +419,7 @@ def completion(
         from google.protobuf.struct_pb2 import Value  # type: ignore
         from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore
         import google.auth  # type: ignore
+        import proto  # type: ignore

         ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
         print_verbose(
@@ -605,9 +606,21 @@ def completion(
             ):
                 function_call = response.candidates[0].content.parts[0].function_call
                 args_dict = {}
-                for k, v in function_call.args.items():
-                    args_dict[k] = v
-                args_str = json.dumps(args_dict)
+
+                # Check if it's a RepeatedComposite instance
+                for key, val in function_call.args.items():
+                    if isinstance(
+                        val, proto.marshal.collections.repeated.RepeatedComposite
+                    ):
+                        # If so, convert to list
+                        args_dict[key] = [v for v in val]
+                    else:
+                        args_dict[key] = val
+
+                try:
+                    args_str = json.dumps(args_dict)
+                except Exception as e:
+                    raise VertexAIError(status_code=422, message=str(e))
                 message = litellm.Message(
                     content=None,
                     tool_calls=[
@@ -810,6 +823,8 @@ def completion(
         setattr(model_response, "usage", usage)
         return model_response
     except Exception as e:
+        if isinstance(e, VertexAIError):
+            raise e
         raise VertexAIError(status_code=500, message=str(e))
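The extra loop exists because `json.dumps` cannot serialize protobuf repeated containers. A rough standalone sketch of the same idea, using a try/except stand-in rather than the actual `proto` isinstance check (illustration only):

import json

def flatten_function_args(args) -> str:
    args_dict = {}
    for key, val in args.items():
        try:
            json.dumps(val)                    # plain scalars/lists/dicts pass through
            args_dict[key] = val
        except TypeError:
            args_dict[key] = [v for v in val]  # repeated proto field -> plain list
    return json.dumps(args_dict)

print(flatten_function_args({"location": "Boston", "days": [1, 2, 3]}))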
@@ -3,7 +3,7 @@
 import os, types
 import json
 from enum import Enum
-import requests, copy
+import requests, copy  # type: ignore
 import time, uuid
 from typing import Callable, Optional, List
 from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
@@ -17,7 +17,7 @@ from .prompt_templates.factory import (
     extract_between_tags,
     parse_xml_params,
 )
-import httpx
+import httpx  # type: ignore


 class VertexAIError(Exception):
@@ -1,8 +1,8 @@
 import os
 import json
 from enum import Enum
-import requests
-import time, httpx
+import requests  # type: ignore
+import time, httpx  # type: ignore
 from typing import Callable, Any
 from litellm.utils import ModelResponse, Usage
 from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -3,8 +3,8 @@ import json, types, time  # noqa: E401
 from contextlib import contextmanager
 from typing import Callable, Dict, Optional, Any, Union, List

-import httpx
-import requests
+import httpx  # type: ignore
+import requests  # type: ignore
 import litellm
 from litellm.utils import ModelResponse, get_secret, Usage
litellm/main.py  (180 lines changed)

@@ -14,7 +14,6 @@ import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
 import httpx
 import litellm
-
 from ._logging import verbose_logger
 from litellm import (  # type: ignore
     client,
@@ -34,9 +33,12 @@ from litellm.utils import (
     async_mock_completion_streaming_obj,
     convert_to_model_response_object,
     token_counter,
+    create_pretrained_tokenizer,
+    create_tokenizer,
     Usage,
     get_optional_params_embeddings,
     get_optional_params_image_gen,
+    supports_httpx_timeout,
 )
 from .llms import (
     anthropic_text,
@@ -44,6 +46,7 @@ from .llms import (
     ai21,
     sagemaker,
     bedrock,
+    triton,
     huggingface_restapi,
     replicate,
     aleph_alpha,
@@ -71,10 +74,13 @@ from .llms.azure_text import AzureTextCompletion
 from .llms.anthropic import AnthropicChatCompletion
 from .llms.anthropic_text import AnthropicTextCompletion
 from .llms.huggingface_restapi import Huggingface
+from .llms.predibase import PredibaseChatCompletion
+from .llms.triton import TritonChatCompletion
 from .llms.prompt_templates.factory import (
     prompt_factory,
     custom_prompt,
     function_call_prompt,
+    map_system_message_pt,
 )
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
@@ -106,6 +112,8 @@ anthropic_text_completions = AnthropicTextCompletion()
 azure_chat_completions = AzureChatCompletion()
 azure_text_completions = AzureTextCompletion()
 huggingface = Huggingface()
+predibase_chat_completions = PredibaseChatCompletion()
+triton_chat_completions = TritonChatCompletion()
 ####### COMPLETION ENDPOINTS ################
@@ -184,6 +192,7 @@ async def acompletion(
     top_p: Optional[float] = None,
     n: Optional[int] = None,
     stream: Optional[bool] = None,
+    stream_options: Optional[dict] = None,
     stop=None,
     max_tokens: Optional[int] = None,
     presence_penalty: Optional[float] = None,
@@ -203,6 +212,7 @@ async def acompletion(
     api_version: Optional[str] = None,
     api_key: Optional[str] = None,
     model_list: Optional[list] = None,  # pass in a list of api_base,keys, etc.
+    extra_headers: Optional[dict] = None,
     # Optional liteLLM function params
     **kwargs,
 ):
@@ -220,6 +230,7 @@ async def acompletion(
         top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0).
         n (int, optional): The number of completions to generate (default is 1).
         stream (bool, optional): If True, return a streaming response (default is False).
+        stream_options (dict, optional): A dictionary containing options for the streaming response. Only use this if stream is True.
         stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens.
         max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
         presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.
@@ -257,6 +268,7 @@ async def acompletion(
         "top_p": top_p,
         "n": n,
         "stream": stream,
+        "stream_options": stream_options,
         "stop": stop,
         "max_tokens": max_tokens,
         "presence_penalty": presence_penalty,
@@ -301,6 +313,7 @@ async def acompletion(
             or custom_llm_provider == "deepinfra"
             or custom_llm_provider == "perplexity"
             or custom_llm_provider == "groq"
+            or custom_llm_provider == "deepseek"
             or custom_llm_provider == "text-completion-openai"
             or custom_llm_provider == "huggingface"
             or custom_llm_provider == "ollama"
@@ -309,6 +322,7 @@ async def acompletion(
             or custom_llm_provider == "gemini"
             or custom_llm_provider == "sagemaker"
             or custom_llm_provider == "anthropic"
+            or custom_llm_provider == "predibase"
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
@@ -448,11 +462,12 @@ def completion(
     model: str,
     # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
     messages: List = [],
-    timeout: Optional[Union[float, int]] = None,
+    timeout: Optional[Union[float, str, httpx.Timeout]] = None,
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     n: Optional[int] = None,
     stream: Optional[bool] = None,
+    stream_options: Optional[dict] = None,
     stop=None,
     max_tokens: Optional[int] = None,
     presence_penalty: Optional[float] = None,
@@ -492,6 +507,7 @@ def completion(
         top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0).
         n (int, optional): The number of completions to generate (default is 1).
         stream (bool, optional): If True, return a streaming response (default is False).
+        stream_options (dict, optional): A dictionary containing options for the streaming response. Only set this when you set stream: true.
         stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens.
         max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
         presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.
@@ -551,6 +567,7 @@ def completion(
     eos_token = kwargs.get("eos_token", None)
     preset_cache_key = kwargs.get("preset_cache_key", None)
     hf_model_name = kwargs.get("hf_model_name", None)
+    supports_system_message = kwargs.get("supports_system_message", None)
     ### TEXT COMPLETION CALLS ###
     text_completion = kwargs.get("text_completion", False)
     atext_completion = kwargs.get("atext_completion", False)
@@ -568,6 +585,7 @@ def completion(
         "top_p",
         "n",
         "stream",
+        "stream_options",
         "stop",
         "max_tokens",
         "presence_penalty",
@@ -616,6 +634,7 @@ def completion(
         "model_list",
         "num_retries",
         "context_window_fallback_dict",
+        "retry_policy",
         "roles",
         "final_prompt_value",
         "bos_token",
@@ -641,16 +660,30 @@ def completion(
         "no-log",
         "base_model",
         "stream_timeout",
+        "supports_system_message",
+        "region_name",
+        "allowed_model_region",
     ]

     default_params = openai_params + litellm_params
     non_default_params = {
         k: v for k, v in kwargs.items() if k not in default_params
     }  # model-specific params - pass them straight to the model/provider
-    if timeout is None:
-        timeout = (
-            kwargs.get("request_timeout", None) or 600
-        )  # set timeout for 10 minutes by default
-    timeout = float(timeout)
+
+    ### TIMEOUT LOGIC ###
+    timeout = timeout or kwargs.get("request_timeout", 600) or 600
+    # set timeout for 10 minutes by default
+
+    if (
+        timeout is not None
+        and isinstance(timeout, httpx.Timeout)
+        and supports_httpx_timeout(custom_llm_provider) == False
+    ):
+        read_timeout = timeout.read or 600
+        timeout = read_timeout  # default 10 min timeout
+    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
+        timeout = float(timeout)  # type: ignore
+
     try:
         if base_url is not None:
             api_base = base_url
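With this change, a granular `httpx.Timeout` can be passed straight into `completion()`; for providers that do not support it, only the read timeout is kept per the logic above. A quick sketch (assumes an OpenAI key in the environment; the model name is illustrative):

import httpx
import litellm

resp = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "ping"}],
    # granular timeout; falls back to read timeout for providers without httpx support
    timeout=httpx.Timeout(connect=5.0, read=30.0, write=30.0, pool=5.0),
)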
@@ -745,6 +778,13 @@ def completion(
             custom_prompt_dict[model]["bos_token"] = bos_token
         if eos_token:
             custom_prompt_dict[model]["eos_token"] = eos_token
+
+        if (
+            supports_system_message is not None
+            and isinstance(supports_system_message, bool)
+            and supports_system_message == False
+        ):
+            messages = map_system_message_pt(messages=messages)
         model_api_key = get_api_key(
             llm_provider=custom_llm_provider, dynamic_api_key=api_key
         )  # get the api key from the environment if required for the model
@@ -759,6 +799,7 @@ def completion(
             top_p=top_p,
             n=n,
             stream=stream,
+            stream_options=stream_options,
             stop=stop,
             max_tokens=max_tokens,
             presence_penalty=presence_penalty,
@@ -871,7 +912,7 @@ def completion(
                 logger_fn=logger_fn,
                 logging_obj=logging,
                 acompletion=acompletion,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
                 client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
             )

@@ -958,6 +999,7 @@ def completion(
             or custom_llm_provider == "deepinfra"
             or custom_llm_provider == "perplexity"
             or custom_llm_provider == "groq"
+            or custom_llm_provider == "deepseek"
             or custom_llm_provider == "anyscale"
             or custom_llm_provider == "mistral"
             or custom_llm_provider == "openai"
@@ -1012,7 +1054,7 @@ def completion(
                 optional_params=optional_params,
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
                 custom_prompt_dict=custom_prompt_dict,
                 client=client,  # pass AsyncOpenAI, OpenAI client
                 organization=organization,
@@ -1097,7 +1139,7 @@ def completion(
                 optional_params=optional_params,
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
             )

         if (
@@ -1471,7 +1513,7 @@ def completion(
                 acompletion=acompletion,
                 logging_obj=logging,
                 custom_prompt_dict=custom_prompt_dict,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
             )
             if (
                 "stream" in optional_params
@@ -1564,7 +1606,7 @@ def completion(
                 logger_fn=logger_fn,
                 logging_obj=logging,
                 acompletion=acompletion,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
             )
             ## LOGGING
             logging.post_call(
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
response = model_response
|
response = model_response
|
||||||
|
elif custom_llm_provider == "predibase":
|
||||||
|
tenant_id = (
|
||||||
|
optional_params.pop("tenant_id", None)
|
||||||
|
or optional_params.pop("predibase_tenant_id", None)
|
||||||
|
or litellm.predibase_tenant_id
|
||||||
|
or get_secret("PREDIBASE_TENANT_ID")
|
||||||
|
)
|
||||||
|
|
||||||
|
api_base = (
|
||||||
|
optional_params.pop("api_base", None)
|
||||||
|
or optional_params.pop("base_url", None)
|
||||||
|
or litellm.api_base
|
||||||
|
or get_secret("PREDIBASE_API_BASE")
|
||||||
|
)
|
||||||
|
|
||||||
|
api_key = (
|
||||||
|
api_key
|
||||||
|
or litellm.api_key
|
||||||
|
or litellm.predibase_key
|
||||||
|
or get_secret("PREDIBASE_API_KEY")
|
||||||
|
)
|
||||||
|
|
||||||
|
_model_response = predibase_chat_completions.completion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging,
|
||||||
|
acompletion=acompletion,
|
||||||
|
api_base=api_base,
|
||||||
|
custom_prompt_dict=custom_prompt_dict,
|
||||||
|
api_key=api_key,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
"stream" in optional_params
|
||||||
|
and optional_params["stream"] == True
|
||||||
|
and acompletion == False
|
||||||
|
):
|
||||||
|
return _model_response
|
||||||
|
response = _model_response
|
||||||
elif custom_llm_provider == "ai21":
|
elif custom_llm_provider == "ai21":
|
||||||
custom_llm_provider = "ai21"
|
custom_llm_provider = "ai21"
|
||||||
ai21_key = (
|
ai21_key = (
|
||||||
|
@@ -1844,6 +1932,7 @@ def completion(
                 logger_fn=logger_fn,
                 encoding=encoding,
                 logging_obj=logging,
+                extra_headers=extra_headers,
                 timeout=timeout,
             )

@@ -1891,7 +1980,7 @@ def completion(
                 logger_fn=logger_fn,
                 encoding=encoding,
                 logging_obj=logging,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
             )
             if (
                 "stream" in optional_params
@@ -2143,7 +2232,7 @@ def completion(
             """
             assume input to custom LLM api bases follow this format:
             resp = requests.post(
                 api_base,
                 json={
                     'model': 'meta-llama/Llama-2-13b-hf', # model name
                     'params': {
@@ -2272,7 +2361,7 @@ def batch_completion(
     n: Optional[int] = None,
     stream: Optional[bool] = None,
     stop=None,
-    max_tokens: Optional[float] = None,
+    max_tokens: Optional[int] = None,
     presence_penalty: Optional[float] = None,
     frequency_penalty: Optional[float] = None,
     logit_bias: Optional[dict] = None,
@@ -2535,11 +2624,13 @@ async def aembedding(*args, **kwargs):
         or custom_llm_provider == "voyage"
         or custom_llm_provider == "mistral"
         or custom_llm_provider == "custom_openai"
+        or custom_llm_provider == "triton"
         or custom_llm_provider == "anyscale"
         or custom_llm_provider == "openrouter"
         or custom_llm_provider == "deepinfra"
         or custom_llm_provider == "perplexity"
         or custom_llm_provider == "groq"
+        or custom_llm_provider == "deepseek"
         or custom_llm_provider == "fireworks_ai"
         or custom_llm_provider == "ollama"
         or custom_llm_provider == "vertex_ai"
@@ -2665,6 +2756,7 @@ def embedding(
         "model_list",
         "num_retries",
         "context_window_fallback_dict",
+        "retry_policy",
         "roles",
         "final_prompt_value",
         "bos_token",
@@ -2688,6 +2780,8 @@ def embedding(
         "ttl",
         "cache",
         "no-log",
+        "region_name",
+        "allowed_model_region",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {
@ -2864,23 +2958,43 @@ def embedding(
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
model_response=EmbeddingResponse(),
|
model_response=EmbeddingResponse(),
|
||||||
)
|
)
|
||||||
|
elif custom_llm_provider == "triton":
|
||||||
|
if api_base is None:
|
||||||
|
raise ValueError(
|
||||||
|
"api_base is required for triton. Please pass `api_base`"
|
||||||
|
)
|
||||||
|
response = triton_chat_completions.embedding(
|
||||||
|
model=model,
|
||||||
|
input=input,
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
logging_obj=logging,
|
||||||
|
timeout=timeout,
|
||||||
|
model_response=EmbeddingResponse(),
|
||||||
|
optional_params=optional_params,
|
||||||
|
client=client,
|
||||||
|
aembedding=aembedding,
|
||||||
|
)
|
||||||
elif custom_llm_provider == "vertex_ai":
|
elif custom_llm_provider == "vertex_ai":
|
||||||
vertex_ai_project = (
|
vertex_ai_project = (
|
||||||
optional_params.pop("vertex_project", None)
|
optional_params.pop("vertex_project", None)
|
||||||
or optional_params.pop("vertex_ai_project", None)
|
or optional_params.pop("vertex_ai_project", None)
|
||||||
or litellm.vertex_project
|
or litellm.vertex_project
|
||||||
or get_secret("VERTEXAI_PROJECT")
|
or get_secret("VERTEXAI_PROJECT")
|
||||||
|
or get_secret("VERTEX_PROJECT")
|
||||||
)
|
)
|
||||||
vertex_ai_location = (
|
vertex_ai_location = (
|
||||||
optional_params.pop("vertex_location", None)
|
optional_params.pop("vertex_location", None)
|
||||||
or optional_params.pop("vertex_ai_location", None)
|
or optional_params.pop("vertex_ai_location", None)
|
||||||
or litellm.vertex_location
|
or litellm.vertex_location
|
||||||
or get_secret("VERTEXAI_LOCATION")
|
or get_secret("VERTEXAI_LOCATION")
|
||||||
|
or get_secret("VERTEX_LOCATION")
|
||||||
)
|
)
|
||||||
vertex_credentials = (
|
vertex_credentials = (
|
||||||
optional_params.pop("vertex_credentials", None)
|
optional_params.pop("vertex_credentials", None)
|
||||||
or optional_params.pop("vertex_ai_credentials", None)
|
or optional_params.pop("vertex_ai_credentials", None)
|
||||||
or get_secret("VERTEXAI_CREDENTIALS")
|
or get_secret("VERTEXAI_CREDENTIALS")
|
||||||
|
or get_secret("VERTEX_CREDENTIALS")
|
||||||
)
|
)
|
||||||
|
|
||||||
response = vertex_ai.embedding(
|
response = vertex_ai.embedding(
|
||||||
|
@ -2921,16 +3035,18 @@ def embedding(
|
||||||
model=model, # type: ignore
|
model=model, # type: ignore
|
||||||
llm_provider="ollama", # type: ignore
|
llm_provider="ollama", # type: ignore
|
||||||
)
|
)
|
||||||
if aembedding:
|
ollama_embeddings_fn = (
|
||||||
response = ollama.ollama_aembeddings(
|
ollama.ollama_aembeddings if aembedding else ollama.ollama_embeddings
|
||||||
api_base=api_base,
|
)
|
||||||
model=model,
|
response = ollama_embeddings_fn(
|
||||||
prompts=input,
|
api_base=api_base,
|
||||||
encoding=encoding,
|
model=model,
|
||||||
logging_obj=logging,
|
prompts=input,
|
||||||
optional_params=optional_params,
|
encoding=encoding,
|
||||||
model_response=EmbeddingResponse(),
|
logging_obj=logging,
|
||||||
)
|
optional_params=optional_params,
|
||||||
|
model_response=EmbeddingResponse(),
|
||||||
|
)
|
||||||
elif custom_llm_provider == "sagemaker":
|
elif custom_llm_provider == "sagemaker":
|
||||||
response = sagemaker.embedding(
|
response = sagemaker.embedding(
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -3059,11 +3175,13 @@ async def atext_completion(*args, **kwargs):
|
||||||
or custom_llm_provider == "deepinfra"
|
or custom_llm_provider == "deepinfra"
|
||||||
or custom_llm_provider == "perplexity"
|
or custom_llm_provider == "perplexity"
|
||||||
or custom_llm_provider == "groq"
|
or custom_llm_provider == "groq"
|
||||||
|
or custom_llm_provider == "deepseek"
|
||||||
or custom_llm_provider == "fireworks_ai"
|
or custom_llm_provider == "fireworks_ai"
|
||||||
or custom_llm_provider == "text-completion-openai"
|
or custom_llm_provider == "text-completion-openai"
|
||||||
or custom_llm_provider == "huggingface"
|
or custom_llm_provider == "huggingface"
|
||||||
or custom_llm_provider == "ollama"
|
or custom_llm_provider == "ollama"
|
||||||
or custom_llm_provider == "vertex_ai"
|
or custom_llm_provider == "vertex_ai"
|
||||||
|
or custom_llm_provider in litellm.openai_compatible_providers
|
||||||
): # currently implemented aiohttp calls for just azure and openai, soon all.
|
): # currently implemented aiohttp calls for just azure and openai, soon all.
|
||||||
# Await normally
|
# Await normally
|
||||||
response = await loop.run_in_executor(None, func_with_context)
|
response = await loop.run_in_executor(None, func_with_context)
|
||||||
|
@@ -3094,6 +3212,8 @@ async def atext_completion(*args, **kwargs):
         ## TRANSLATE CHAT TO TEXT FORMAT ##
         if isinstance(response, TextCompletionResponse):
             return response
+        elif asyncio.iscoroutine(response):
+            response = await response

         text_completion_response = TextCompletionResponse()
         text_completion_response["id"] = response.get("id", None)
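The branch added above normalizes a result that may arrive either as a ready response or as a coroutine before the chat-to-text translation runs. A purely illustrative sketch of that pattern; the names are not from litellm.

import asyncio

async def normalize(result):
    # await lazily returned results, pass ready ones through
    if asyncio.iscoroutine(result):
        result = await result
    return result

async def main():
    async def later():
        return {"id": "cmpl-123"}
    print(await normalize({"id": "cmpl-456"}))  # already a dict
    print(await normalize(later()))             # coroutine gets awaited

asyncio.run(main())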
@@ -3153,6 +3273,7 @@ def text_completion(
         Union[str, List[str]]
     ] = None,  # Optional: Sequences where the API will stop generating further tokens.
     stream: Optional[bool] = None,  # Optional: Whether to stream back partial progress.
+    stream_options: Optional[dict] = None,
     suffix: Optional[
         str
     ] = None,  # Optional: The suffix that comes after a completion of inserted text.
@@ -3230,6 +3351,8 @@ def text_completion(
         optional_params["stop"] = stop
     if stream is not None:
         optional_params["stream"] = stream
+    if stream_options is not None:
+        optional_params["stream_options"] = stream_options
     if suffix is not None:
         optional_params["suffix"] = suffix
     if temperature is not None:
@@ -3340,7 +3463,9 @@ def text_completion(
     if kwargs.get("acompletion", False) == True:
         return response
     if stream == True or kwargs.get("stream", False) == True:
-        response = TextCompletionStreamWrapper(completion_stream=response, model=model)
+        response = TextCompletionStreamWrapper(
+            completion_stream=response, model=model, stream_options=stream_options
+        )
         return response
     transformed_logprobs = None
     # only supported for TGI models
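The three hunks above plumb a new stream_options parameter through text_completion: it is accepted in the signature, forwarded via optional_params, and handed to TextCompletionStreamWrapper. A hedged usage sketch; {"include_usage": True} follows the OpenAI convention, and whether a given provider honors it is an assumption here.

import litellm

stream = litellm.text_completion(
    model="gpt-3.5-turbo-instruct",
    prompt="Once upon a time",
    stream=True,
    stream_options={"include_usage": True},
    max_tokens=16,
)
for chunk in stream:
    print(chunk)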
@@ -3534,6 +3659,7 @@ def image_generation(
         "model_list",
         "num_retries",
         "context_window_fallback_dict",
+        "retry_policy",
         "roles",
         "final_prompt_value",
         "bos_token",
@@ -3554,6 +3680,8 @@ def image_generation(
         "caching_groups",
         "ttl",
         "cache",
+        "region_name",
+        "allowed_model_region",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {

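The newly added names ("retry_policy", "region_name", "allowed_model_region") join the litellm housekeeping parameters that image_generation excludes when it builds default_params and non_default_params, as visible at the end of the hunk. A simplified stand-in for that filtering, not litellm's actual code; the openai_params excerpt below is illustrative only.

litellm_params = ["model_list", "num_retries", "context_window_fallback_dict",
                  "retry_policy", "region_name", "allowed_model_region"]  # excerpt
openai_params = ["n", "size", "quality", "response_format"]               # illustrative

def split_non_default_params(kwargs):
    # anything not recognized as an OpenAI or litellm param is a provider-specific extra
    default_params = openai_params + litellm_params
    return {k: v for k, v in kwargs.items() if k not in default_params}

print(split_non_default_params({"n": 1, "retry_policy": None, "watermark": False}))
# -> {'watermark': False}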
@@ -338,6 +338,18 @@
         "output_cost_per_second": 0.0001,
         "litellm_provider": "azure"
     },
+    "azure/gpt-4-turbo-2024-04-09": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00001,
+        "output_cost_per_token": 0.00003,
+        "litellm_provider": "azure",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "supports_parallel_function_calling": true,
+        "supports_vision": true
+    },
     "azure/gpt-4-0125-preview": {
         "max_tokens": 4096,
         "max_input_tokens": 128000,
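The JSON hunks from here on appear to target litellm's model cost map (model_prices_and_context_window.json is an assumption; the file header is not preserved in this diff view). A quick cost check against the new azure/gpt-4-turbo-2024-04-09 entry, using only the per-token prices above; litellm's own cost helpers would normally read these from the map.

input_cost_per_token = 0.00001    # $ per prompt token, from the entry above
output_cost_per_token = 0.00003   # $ per completion token, from the entry above

prompt_tokens, completion_tokens = 1_000, 500
total = prompt_tokens * input_cost_per_token + completion_tokens * output_cost_per_token
print(f"${total:.4f}")  # $0.0250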
@@ -727,6 +739,24 @@
         "litellm_provider": "mistral",
         "mode": "embedding"
     },
+    "deepseek-chat": {
+        "max_tokens": 4096,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00000014,
+        "output_cost_per_token": 0.00000028,
+        "litellm_provider": "deepseek",
+        "mode": "chat"
+    },
+    "deepseek-coder": {
+        "max_tokens": 4096,
+        "max_input_tokens": 16000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00000014,
+        "output_cost_per_token": 0.00000028,
+        "litellm_provider": "deepseek",
+        "mode": "chat"
+    },
     "groq/llama2-70b-4096": {
         "max_tokens": 4096,
         "max_input_tokens": 4096,
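With deepseek-chat and deepseek-coder registered in the cost map, litellm can price their responses. A hedged sketch; it assumes DEEPSEEK_API_KEY is set and that completion_cost reads the per-token prices added above.

import litellm

resp = litellm.completion(
    model="deepseek/deepseek-chat",
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
)
print(resp.choices[0].message.content)
print(litellm.completion_cost(completion_response=resp))  # priced from the map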
@@ -813,6 +843,7 @@
         "litellm_provider": "anthropic",
         "mode": "chat",
         "supports_function_calling": true,
+        "supports_vision": true,
         "tool_use_system_prompt_tokens": 264
     },
     "claude-3-opus-20240229": {
@@ -824,6 +855,7 @@
         "litellm_provider": "anthropic",
         "mode": "chat",
         "supports_function_calling": true,
+        "supports_vision": true,
         "tool_use_system_prompt_tokens": 395
     },
     "claude-3-sonnet-20240229": {
@@ -835,6 +867,7 @@
         "litellm_provider": "anthropic",
         "mode": "chat",
         "supports_function_calling": true,
+        "supports_vision": true,
         "tool_use_system_prompt_tokens": 159
     },
     "text-bison": {
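The three hunks above mark the Anthropic Claude 3 entries as vision-capable. A hedged sketch of how that flag becomes queryable at runtime, assuming litellm.supports_vision is available in this version and reads from the model map.

import litellm

for m in ["claude-3-opus-20240229", "claude-3-haiku-20240307"]:
    print(m, litellm.supports_vision(model=m))  # expected: True after this change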
@@ -1045,8 +1078,8 @@
         "max_tokens": 8192,
         "max_input_tokens": 1000000,
         "max_output_tokens": 8192,
-        "input_cost_per_token": 0,
-        "output_cost_per_token": 0,
+        "input_cost_per_token": 0.000000625,
+        "output_cost_per_token": 0.000001875,
         "litellm_provider": "vertex_ai-language-models",
         "mode": "chat",
         "supports_function_calling": true,
@@ -1057,8 +1090,8 @@
         "max_tokens": 8192,
         "max_input_tokens": 1000000,
         "max_output_tokens": 8192,
-        "input_cost_per_token": 0,
-        "output_cost_per_token": 0,
+        "input_cost_per_token": 0.000000625,
+        "output_cost_per_token": 0.000001875,
         "litellm_provider": "vertex_ai-language-models",
         "mode": "chat",
         "supports_function_calling": true,
@@ -1069,8 +1102,8 @@
         "max_tokens": 8192,
         "max_input_tokens": 1000000,
         "max_output_tokens": 8192,
-        "input_cost_per_token": 0,
-        "output_cost_per_token": 0,
+        "input_cost_per_token": 0.000000625,
+        "output_cost_per_token": 0.000001875,
         "litellm_provider": "vertex_ai-language-models",
         "mode": "chat",
         "supports_function_calling": true,
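The three hunks above replace placeholder zero pricing on these vertex_ai-language-models entries with real per-token rates. A small arithmetic check using only the numbers from the JSON above.

input_cost_per_token = 0.000000625
output_cost_per_token = 0.000001875

prompt_tokens, completion_tokens = 100_000, 1_000
cost = prompt_tokens * input_cost_per_token + completion_tokens * output_cost_per_token
print(f"${cost:.6f}")  # $0.064375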
@@ -1142,7 +1175,8 @@
         "output_cost_per_token": 0.000015,
         "litellm_provider": "vertex_ai-anthropic_models",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_vision": true
     },
     "vertex_ai/claude-3-haiku@20240307": {
         "max_tokens": 4096,
@@ -1152,7 +1186,8 @@
         "output_cost_per_token": 0.00000125,
         "litellm_provider": "vertex_ai-anthropic_models",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_vision": true
     },
     "vertex_ai/claude-3-opus@20240229": {
         "max_tokens": 4096,
@@ -1162,7 +1197,8 @@
         "output_cost_per_token": 0.0000075,
         "litellm_provider": "vertex_ai-anthropic_models",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_vision": true
     },
     "textembedding-gecko": {
         "max_tokens": 3072,
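These vertex_ai-anthropic_models entries now also advertise vision support. A hedged sketch of what that enables: sending an OpenAI-style image_url content part through litellm. The image URL is a placeholder and working Vertex AI credentials are assumed.

import litellm

resp = litellm.completion(
    model="vertex_ai/claude-3-haiku@20240307",  # present in the map above
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image in one sentence."},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }],
)
print(resp.choices[0].message.content)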
@@ -1581,6 +1617,7 @@
         "litellm_provider": "openrouter",
         "mode": "chat",
         "supports_function_calling": true,
+        "supports_vision": true,
         "tool_use_system_prompt_tokens": 395
     },
     "openrouter/google/palm-2-chat-bison": {
@@ -1813,6 +1850,15 @@
         "litellm_provider": "bedrock",
         "mode": "embedding"
     },
+    "amazon.titan-embed-text-v2:0": {
+        "max_tokens": 8192,
+        "max_input_tokens": 8192,
+        "output_vector_size": 1024,
+        "input_cost_per_token": 0.0000002,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "bedrock",
+        "mode": "embedding"
+    },
     "mistral.mistral-7b-instruct-v0:2": {
         "max_tokens": 8191,
         "max_input_tokens": 32000,
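A hedged sketch for the new Bedrock embedding entry above. It assumes AWS credentials and region are configured (e.g. AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_REGION_NAME); the expected vector length matches output_vector_size in the entry.

import litellm

resp = litellm.embedding(
    model="bedrock/amazon.titan-embed-text-v2:0",
    input=["hello world"],
)
print(len(resp.data[0]["embedding"]))  # expected: 1024, per output_vector_size above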
@@ -1929,7 +1975,8 @@
         "output_cost_per_token": 0.000015,
         "litellm_provider": "bedrock",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_vision": true
     },
     "anthropic.claude-3-haiku-20240307-v1:0": {
         "max_tokens": 4096,
@@ -1939,7 +1986,8 @@
         "output_cost_per_token": 0.00000125,
         "litellm_provider": "bedrock",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_vision": true
     },
     "anthropic.claude-3-opus-20240229-v1:0": {
         "max_tokens": 4096,
@@ -1949,7 +1997,8 @@
         "output_cost_per_token": 0.000075,
         "litellm_provider": "bedrock",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_vision": true
     },
     "anthropic.claude-v1": {
         "max_tokens": 8191,
File diff suppressed because one or more lines are too long
@@ -0,0 +1 @@
+(self.webpackChunk_N_E=self.webpackChunk_N_E||[]).push([[185],{87421:function(n,e,t){Promise.resolve().then(t.t.bind(t,99646,23)),Promise.resolve().then(t.t.bind(t,63385,23))},63385:function(){},99646:function(n){n.exports={style:{fontFamily:"'__Inter_c23dc8', '__Inter_Fallback_c23dc8'",fontStyle:"normal"},className:"__className_c23dc8"}}},function(n){n.O(0,[971,69,744],function(){return n(n.s=87421)}),_N_E=n.O()}]);

@@ -1 +0,0 @@
-(self.webpackChunk_N_E=self.webpackChunk_N_E||[]).push([[185],{93553:function(n,e,t){Promise.resolve().then(t.t.bind(t,63385,23)),Promise.resolve().then(t.t.bind(t,99646,23))},63385:function(){},99646:function(n){n.exports={style:{fontFamily:"'__Inter_12bbc4', '__Inter_Fallback_12bbc4'",fontStyle:"normal"},className:"__className_12bbc4"}}},function(n){n.O(0,[971,69,744],function(){return n(n.s=93553)}),_N_E=n.O()}]);
File diff suppressed because one or more lines are too long
Some files were not shown because too many files have changed in this diff.