forked from phoenix/litellm-mirror
Merge branch 'main' into feat/add-azure-content-filter
This commit is contained in:
commit
bbe1300c5b
200 changed files with 19218 additions and 2966 deletions
|
@ -1,4 +1,4 @@
|
|||
version: 2.1
|
||||
version: 4.3.4
|
||||
jobs:
|
||||
local_testing:
|
||||
docker:
|
||||
|
@ -188,7 +188,7 @@ jobs:
|
|||
command: |
|
||||
docker run -d \
|
||||
-p 4000:4000 \
|
||||
-e DATABASE_URL=$PROXY_DOCKER_DB_URL \
|
||||
-e DATABASE_URL=$PROXY_DATABASE_URL \
|
||||
-e AZURE_API_KEY=$AZURE_API_KEY \
|
||||
-e REDIS_HOST=$REDIS_HOST \
|
||||
-e REDIS_PASSWORD=$REDIS_PASSWORD \
|
||||
|
@ -198,6 +198,7 @@ jobs:
|
|||
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
|
||||
-e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
|
||||
-e AWS_REGION_NAME=$AWS_REGION_NAME \
|
||||
-e AUTO_INFER_REGION=True \
|
||||
-e OPENAI_API_KEY=$OPENAI_API_KEY \
|
||||
-e LANGFUSE_PROJECT1_PUBLIC=$LANGFUSE_PROJECT1_PUBLIC \
|
||||
-e LANGFUSE_PROJECT2_PUBLIC=$LANGFUSE_PROJECT2_PUBLIC \
|
||||
|
@ -208,9 +209,7 @@ jobs:
|
|||
my-app:latest \
|
||||
--config /app/config.yaml \
|
||||
--port 4000 \
|
||||
--num_workers 8 \
|
||||
--detailed_debug \
|
||||
--run_gunicorn \
|
||||
- run:
|
||||
name: Install curl and dockerize
|
||||
command: |
|
||||
|
@ -225,7 +224,7 @@ jobs:
|
|||
background: true
|
||||
- run:
|
||||
name: Wait for app to be ready
|
||||
command: dockerize -wait http://localhost:4000 -timeout 1m
|
||||
command: dockerize -wait http://localhost:4000 -timeout 5m
|
||||
- run:
|
||||
name: Run tests
|
||||
command: |
|
||||
|
|
51
.devcontainer/devcontainer.json
Normal file
51
.devcontainer/devcontainer.json
Normal file
|
@ -0,0 +1,51 @@
|
|||
{
|
||||
"name": "Python 3.11",
|
||||
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
|
||||
"image": "mcr.microsoft.com/devcontainers/python:3.11-bookworm",
|
||||
// https://github.com/devcontainers/images/tree/main/src/python
|
||||
// https://mcr.microsoft.com/en-us/product/devcontainers/python/tags
|
||||
|
||||
// "build": {
|
||||
// "dockerfile": "Dockerfile",
|
||||
// "context": ".."
|
||||
// },
|
||||
|
||||
// Features to add to the dev container. More info: https://containers.dev/features.
|
||||
// "features": {},
|
||||
|
||||
// Configure tool-specific properties.
|
||||
"customizations": {
|
||||
// Configure properties specific to VS Code.
|
||||
"vscode": {
|
||||
"settings": {},
|
||||
"extensions": [
|
||||
"ms-python.python",
|
||||
"ms-python.vscode-pylance",
|
||||
"GitHub.copilot",
|
||||
"GitHub.copilot-chat"
|
||||
]
|
||||
}
|
||||
},
|
||||
|
||||
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
||||
"forwardPorts": [4000],
|
||||
|
||||
"containerEnv": {
|
||||
"LITELLM_LOG": "DEBUG"
|
||||
},
|
||||
|
||||
// Use 'portsAttributes' to set default properties for specific forwarded ports.
|
||||
// More info: https://containers.dev/implementors/json_reference/#port-attributes
|
||||
"portsAttributes": {
|
||||
"4000": {
|
||||
"label": "LiteLLM Server",
|
||||
"onAutoForward": "notify"
|
||||
}
|
||||
},
|
||||
|
||||
// More info: https://aka.ms/dev-containers-non-root.
|
||||
// "remoteUser": "litellm",
|
||||
|
||||
// Use 'postCreateCommand' to run commands after the container is created.
|
||||
"postCreateCommand": "pipx install poetry && poetry install -E extra_proxy -E proxy"
|
||||
}
|
19
.github/workflows/interpret_load_test.py
vendored
19
.github/workflows/interpret_load_test.py
vendored
|
@ -64,6 +64,11 @@ if __name__ == "__main__":
|
|||
) # Replace with your repository's username and name
|
||||
latest_release = repo.get_latest_release()
|
||||
print("got latest release: ", latest_release)
|
||||
print(latest_release.title)
|
||||
print(latest_release.tag_name)
|
||||
|
||||
release_version = latest_release.title
|
||||
|
||||
print("latest release body: ", latest_release.body)
|
||||
print("markdown table: ", markdown_table)
|
||||
|
||||
|
@ -74,8 +79,22 @@ if __name__ == "__main__":
|
|||
start_index = latest_release.body.find("Load Test LiteLLM Proxy Results")
|
||||
existing_release_body = latest_release.body[:start_index]
|
||||
|
||||
docker_run_command = f"""
|
||||
\n\n
|
||||
## Docker Run LiteLLM Proxy
|
||||
|
||||
```
|
||||
docker run \\
|
||||
-e STORE_MODEL_IN_DB=True \\
|
||||
-p 4000:4000 \\
|
||||
ghcr.io/berriai/litellm:main-{release_version}
|
||||
```
|
||||
"""
|
||||
print("docker run command: ", docker_run_command)
|
||||
|
||||
new_release_body = (
|
||||
existing_release_body
|
||||
+ docker_run_command
|
||||
+ "\n\n"
|
||||
+ "### Don't want to maintain your internal proxy? get in touch 🎉"
|
||||
+ "\nHosted Proxy Alpha: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat"
|
||||
|
|
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -1,5 +1,6 @@
|
|||
.venv
|
||||
.env
|
||||
litellm/proxy/myenv/*
|
||||
litellm_uuid.txt
|
||||
__pycache__/
|
||||
*.pyc
|
||||
|
@ -52,3 +53,6 @@ litellm/proxy/_new_secret_config.yaml
|
|||
litellm/proxy/_new_secret_config.yaml
|
||||
litellm/proxy/_super_secret_config.yaml
|
||||
litellm/proxy/_super_secret_config.yaml
|
||||
litellm/proxy/myenv/bin/activate
|
||||
litellm/proxy/myenv/bin/Activate.ps1
|
||||
myenv/*
|
|
@ -226,6 +226,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
|
|||
| [deepinfra](https://docs.litellm.ai/docs/providers/deepinfra) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [perplexity-ai](https://docs.litellm.ai/docs/providers/perplexity) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [Groq AI](https://docs.litellm.ai/docs/providers/groq) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [Deepseek](https://docs.litellm.ai/docs/providers/deepseek) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [anyscale](https://docs.litellm.ai/docs/providers/anyscale) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [IBM - watsonx.ai](https://docs.litellm.ai/docs/providers/watsonx) | ✅ | ✅ | ✅ | ✅ | ✅
|
||||
| [voyage ai](https://docs.litellm.ai/docs/providers/voyage) | | | | | ✅ |
|
||||
|
|
BIN
deploy/azure_resource_manager/azure_marketplace.zip
Normal file
BIN
deploy/azure_resource_manager/azure_marketplace.zip
Normal file
Binary file not shown.
|
@ -0,0 +1,15 @@
|
|||
{
|
||||
"$schema": "https://schema.management.azure.com/schemas/0.1.2-preview/CreateUIDefinition.MultiVm.json#",
|
||||
"handler": "Microsoft.Azure.CreateUIDef",
|
||||
"version": "0.1.2-preview",
|
||||
"parameters": {
|
||||
"config": {
|
||||
"isWizard": false,
|
||||
"basics": { }
|
||||
},
|
||||
"basics": [ ],
|
||||
"steps": [ ],
|
||||
"outputs": { },
|
||||
"resourceTypes": [ ]
|
||||
}
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
{
|
||||
"$schema": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#",
|
||||
"contentVersion": "1.0.0.0",
|
||||
"parameters": {
|
||||
"imageName": {
|
||||
"type": "string",
|
||||
"defaultValue": "ghcr.io/berriai/litellm:main-latest"
|
||||
},
|
||||
"containerName": {
|
||||
"type": "string",
|
||||
"defaultValue": "litellm-container"
|
||||
},
|
||||
"dnsLabelName": {
|
||||
"type": "string",
|
||||
"defaultValue": "litellm"
|
||||
},
|
||||
"portNumber": {
|
||||
"type": "int",
|
||||
"defaultValue": 4000
|
||||
}
|
||||
},
|
||||
"resources": [
|
||||
{
|
||||
"type": "Microsoft.ContainerInstance/containerGroups",
|
||||
"apiVersion": "2021-03-01",
|
||||
"name": "[parameters('containerName')]",
|
||||
"location": "[resourceGroup().location]",
|
||||
"properties": {
|
||||
"containers": [
|
||||
{
|
||||
"name": "[parameters('containerName')]",
|
||||
"properties": {
|
||||
"image": "[parameters('imageName')]",
|
||||
"resources": {
|
||||
"requests": {
|
||||
"cpu": 1,
|
||||
"memoryInGB": 2
|
||||
}
|
||||
},
|
||||
"ports": [
|
||||
{
|
||||
"port": "[parameters('portNumber')]"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"osType": "Linux",
|
||||
"restartPolicy": "Always",
|
||||
"ipAddress": {
|
||||
"type": "Public",
|
||||
"ports": [
|
||||
{
|
||||
"protocol": "tcp",
|
||||
"port": "[parameters('portNumber')]"
|
||||
}
|
||||
],
|
||||
"dnsNameLabel": "[parameters('dnsLabelName')]"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
42
deploy/azure_resource_manager/main.bicep
Normal file
42
deploy/azure_resource_manager/main.bicep
Normal file
|
@ -0,0 +1,42 @@
|
|||
param imageName string = 'ghcr.io/berriai/litellm:main-latest'
|
||||
param containerName string = 'litellm-container'
|
||||
param dnsLabelName string = 'litellm'
|
||||
param portNumber int = 4000
|
||||
|
||||
resource containerGroupName 'Microsoft.ContainerInstance/containerGroups@2021-03-01' = {
|
||||
name: containerName
|
||||
location: resourceGroup().location
|
||||
properties: {
|
||||
containers: [
|
||||
{
|
||||
name: containerName
|
||||
properties: {
|
||||
image: imageName
|
||||
resources: {
|
||||
requests: {
|
||||
cpu: 1
|
||||
memoryInGB: 2
|
||||
}
|
||||
}
|
||||
ports: [
|
||||
{
|
||||
port: portNumber
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
osType: 'Linux'
|
||||
restartPolicy: 'Always'
|
||||
ipAddress: {
|
||||
type: 'Public'
|
||||
ports: [
|
||||
{
|
||||
protocol: 'tcp'
|
||||
port: portNumber
|
||||
}
|
||||
]
|
||||
dnsNameLabel: dnsLabelName
|
||||
}
|
||||
}
|
||||
}
|
|
@ -24,7 +24,7 @@ version: 0.2.0
|
|||
# incremented each time you make changes to the application. Versions are not expected to
|
||||
# follow Semantic Versioning. They should reflect the version the application is using.
|
||||
# It is recommended to use it with quotes.
|
||||
appVersion: v1.24.5
|
||||
appVersion: v1.35.38
|
||||
|
||||
dependencies:
|
||||
- name: "postgresql"
|
||||
|
|
|
@ -83,8 +83,9 @@ def completion(
|
|||
top_p: Optional[float] = None,
|
||||
n: Optional[int] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[dict] = None,
|
||||
stop=None,
|
||||
max_tokens: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[dict] = None,
|
||||
|
@ -139,6 +140,10 @@ def completion(
|
|||
|
||||
- `stream`: *boolean or null (optional)* - If set to true, it sends partial message deltas. Tokens will be sent as they become available, with the stream terminated by a [DONE] message.
|
||||
|
||||
- `stream_options` *dict or null (optional)* - Options for streaming response. Only set this when you set `stream: true`
|
||||
|
||||
- `include_usage` *boolean (optional)* - If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
|
||||
|
||||
- `stop`: *string/ array/ null (optional)* - Up to 4 sequences where the API will stop generating further tokens.
|
||||
|
||||
- `max_tokens`: *integer (optional)* - The maximum number of tokens to generate in the chat completion.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Completion Token Usage & Cost
|
||||
By default LiteLLM returns token usage in all completion requests ([See here](https://litellm.readthedocs.io/en/latest/output/))
|
||||
|
||||
However, we also expose 5 helper functions + **[NEW]** an API to calculate token usage across providers:
|
||||
However, we also expose some helper functions + **[NEW]** an API to calculate token usage across providers:
|
||||
|
||||
- `encode`: This encodes the text passed in, using the model-specific tokenizer. [**Jump to code**](#1-encode)
|
||||
|
||||
|
@ -9,17 +9,19 @@ However, we also expose 5 helper functions + **[NEW]** an API to calculate token
|
|||
|
||||
- `token_counter`: This returns the number of tokens for a given input - it uses the tokenizer based on the model, and defaults to tiktoken if no model-specific tokenizer is available. [**Jump to code**](#3-token_counter)
|
||||
|
||||
- `cost_per_token`: This returns the cost (in USD) for prompt (input) and completion (output) tokens. Uses the live list from `api.litellm.ai`. [**Jump to code**](#4-cost_per_token)
|
||||
- `create_pretrained_tokenizer` and `create_tokenizer`: LiteLLM provides default tokenizer support for OpenAI, Cohere, Anthropic, Llama2, and Llama3 models. If you are using a different model, you can create a custom tokenizer and pass it as `custom_tokenizer` to the `encode`, `decode`, and `token_counter` methods. [**Jump to code**](#4-create_pretrained_tokenizer-and-create_tokenizer)
|
||||
|
||||
- `completion_cost`: This returns the overall cost (in USD) for a given LLM API Call. It combines `token_counter` and `cost_per_token` to return the cost for that query (counting both cost of input and output). [**Jump to code**](#5-completion_cost)
|
||||
- `cost_per_token`: This returns the cost (in USD) for prompt (input) and completion (output) tokens. Uses the live list from `api.litellm.ai`. [**Jump to code**](#5-cost_per_token)
|
||||
|
||||
- `get_max_tokens`: This returns the maximum number of tokens allowed for the given model. [**Jump to code**](#6-get_max_tokens)
|
||||
- `completion_cost`: This returns the overall cost (in USD) for a given LLM API Call. It combines `token_counter` and `cost_per_token` to return the cost for that query (counting both cost of input and output). [**Jump to code**](#6-completion_cost)
|
||||
|
||||
- `model_cost`: This returns a dictionary for all models, with their max_tokens, input_cost_per_token and output_cost_per_token. It uses the `api.litellm.ai` call shown below. [**Jump to code**](#7-model_cost)
|
||||
- `get_max_tokens`: This returns the maximum number of tokens allowed for the given model. [**Jump to code**](#7-get_max_tokens)
|
||||
|
||||
- `register_model`: This registers new / overrides existing models (and their pricing details) in the model cost dictionary. [**Jump to code**](#8-register_model)
|
||||
- `model_cost`: This returns a dictionary for all models, with their max_tokens, input_cost_per_token and output_cost_per_token. It uses the `api.litellm.ai` call shown below. [**Jump to code**](#8-model_cost)
|
||||
|
||||
- `api.litellm.ai`: Live token + price count across [all supported models](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). [**Jump to code**](#9-apilitellmai)
|
||||
- `register_model`: This registers new / overrides existing models (and their pricing details) in the model cost dictionary. [**Jump to code**](#9-register_model)
|
||||
|
||||
- `api.litellm.ai`: Live token + price count across [all supported models](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). [**Jump to code**](#10-apilitellmai)
|
||||
|
||||
📣 This is a community maintained list. Contributions are welcome! ❤️
|
||||
|
||||
|
@ -60,7 +62,24 @@ messages = [{"user": "role", "content": "Hey, how's it going"}]
|
|||
print(token_counter(model="gpt-3.5-turbo", messages=messages))
|
||||
```
|
||||
|
||||
### 4. `cost_per_token`
|
||||
### 4. `create_pretrained_tokenizer` and `create_tokenizer`
|
||||
|
||||
```python
|
||||
from litellm import create_pretrained_tokenizer, create_tokenizer
|
||||
|
||||
# get tokenizer from huggingface repo
|
||||
custom_tokenizer_1 = create_pretrained_tokenizer("Xenova/llama-3-tokenizer")
|
||||
|
||||
# use tokenizer from json file
|
||||
with open("tokenizer.json") as f:
|
||||
json_data = json.load(f)
|
||||
|
||||
json_str = json.dumps(json_data)
|
||||
|
||||
custom_tokenizer_2 = create_tokenizer(json_str)
|
||||
```
|
||||
|
||||
### 5. `cost_per_token`
|
||||
|
||||
```python
|
||||
from litellm import cost_per_token
|
||||
|
@ -72,7 +91,7 @@ prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_toke
|
|||
print(prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar)
|
||||
```
|
||||
|
||||
### 5. `completion_cost`
|
||||
### 6. `completion_cost`
|
||||
|
||||
* Input: Accepts a `litellm.completion()` response **OR** prompt + completion strings
|
||||
* Output: Returns a `float` of cost for the `completion` call
|
||||
|
@ -99,7 +118,7 @@ cost = completion_cost(model="bedrock/anthropic.claude-v2", prompt="Hey!", compl
|
|||
formatted_string = f"${float(cost):.10f}"
|
||||
print(formatted_string)
|
||||
```
|
||||
### 6. `get_max_tokens`
|
||||
### 7. `get_max_tokens`
|
||||
|
||||
Input: Accepts a model name - e.g., gpt-3.5-turbo (to get a complete list, call litellm.model_list).
|
||||
Output: Returns the maximum number of tokens allowed for the given model
|
||||
|
@ -112,7 +131,7 @@ model = "gpt-3.5-turbo"
|
|||
print(get_max_tokens(model)) # Output: 4097
|
||||
```
|
||||
|
||||
### 7. `model_cost`
|
||||
### 8. `model_cost`
|
||||
|
||||
* Output: Returns a dict object containing the max_tokens, input_cost_per_token, output_cost_per_token for all models on [community-maintained list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
|
||||
|
||||
|
@ -122,7 +141,7 @@ from litellm import model_cost
|
|||
print(model_cost) # {'gpt-3.5-turbo': {'max_tokens': 4000, 'input_cost_per_token': 1.5e-06, 'output_cost_per_token': 2e-06}, ...}
|
||||
```
|
||||
|
||||
### 8. `register_model`
|
||||
### 9. `register_model`
|
||||
|
||||
* Input: Provide EITHER a model cost dictionary or a url to a hosted json blob
|
||||
* Output: Returns updated model_cost dictionary + updates litellm.model_cost with model details.
|
||||
|
@ -157,5 +176,3 @@ export LITELLM_LOCAL_MODEL_COST_MAP="True"
|
|||
```
|
||||
|
||||
Note: this means you will need to upgrade to get updated pricing, and newer models.
|
||||
|
||||
|
||||
|
|
|
@ -320,8 +320,6 @@ from litellm import embedding
|
|||
litellm.vertex_project = "hardy-device-38811" # Your Project ID
|
||||
litellm.vertex_location = "us-central1" # proj location
|
||||
|
||||
|
||||
os.environ['VOYAGE_API_KEY'] = ""
|
||||
response = embedding(
|
||||
model="vertex_ai/textembedding-gecko",
|
||||
input=["good morning from litellm"],
|
||||
|
|
|
@ -17,6 +17,14 @@ This covers:
|
|||
- ✅ [**JWT-Auth**](../docs/proxy/token_auth.md)
|
||||
|
||||
|
||||
## [COMING SOON] AWS Marketplace Support
|
||||
|
||||
Deploy managed LiteLLM Proxy within your VPC.
|
||||
|
||||
Includes all enterprise features.
|
||||
|
||||
[**Get early access**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
||||
|
||||
## Frequently Asked Questions
|
||||
|
||||
### What topics does Professional support cover and what SLAs do you offer?
|
||||
|
|
|
@ -13,7 +13,7 @@ LiteLLM maps exceptions across all providers to their OpenAI counterparts.
|
|||
| >=500 | InternalServerError |
|
||||
| N/A | ContextWindowExceededError|
|
||||
| 400 | ContentPolicyViolationError|
|
||||
| N/A | APIConnectionError |
|
||||
| 500 | APIConnectionError |
|
||||
|
||||
|
||||
Base case we return APIConnectionError
|
||||
|
@ -74,6 +74,28 @@ except Exception as e:
|
|||
|
||||
```
|
||||
|
||||
## Usage - Should you retry exception?
|
||||
|
||||
```
|
||||
import litellm
|
||||
import openai
|
||||
|
||||
try:
|
||||
response = litellm.completion(
|
||||
model="gpt-4",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hello, write a 20 pageg essay"
|
||||
}
|
||||
],
|
||||
timeout=0.01, # this will raise a timeout exception
|
||||
)
|
||||
except openai.APITimeoutError as e:
|
||||
should_retry = litellm._should_retry(e.status_code)
|
||||
print(f"should_retry: {should_retry}")
|
||||
```
|
||||
|
||||
## Details
|
||||
|
||||
To see how it's implemented - [check out the code](https://github.com/BerriAI/litellm/blob/a42c197e5a6de56ea576c73715e6c7c6b19fa249/litellm/utils.py#L1217)
|
||||
|
@ -86,21 +108,34 @@ To see how it's implemented - [check out the code](https://github.com/BerriAI/li
|
|||
|
||||
Base case - we return the original exception.
|
||||
|
||||
| | ContextWindowExceededError | AuthenticationError | InvalidRequestError | RateLimitError | ServiceUnavailableError |
|
||||
|---------------|----------------------------|---------------------|---------------------|---------------|-------------------------|
|
||||
| Anthropic | ✅ | ✅ | ✅ | ✅ | |
|
||||
| OpenAI | ✅ | ✅ |✅ |✅ |✅|
|
||||
| Azure OpenAI | ✅ | ✅ |✅ |✅ |✅|
|
||||
| Replicate | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Cohere | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Huggingface | ✅ | ✅ | ✅ | ✅ | |
|
||||
| Openrouter | ✅ | ✅ | ✅ | ✅ | |
|
||||
| AI21 | ✅ | ✅ | ✅ | ✅ | |
|
||||
| VertexAI | | |✅ | | |
|
||||
| Bedrock | | |✅ | | |
|
||||
| Sagemaker | | |✅ | | |
|
||||
| TogetherAI | ✅ | ✅ | ✅ | ✅ | |
|
||||
| AlephAlpha | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| custom_llm_provider | Timeout | ContextWindowExceededError | BadRequestError | NotFoundError | ContentPolicyViolationError | AuthenticationError | APIError | RateLimitError | ServiceUnavailableError | PermissionDeniedError | UnprocessableEntityError |
|
||||
|----------------------------|---------|----------------------------|------------------|---------------|-----------------------------|---------------------|----------|----------------|-------------------------|-----------------------|-------------------------|
|
||||
| openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
|
||||
| text-completion-openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
|
||||
| custom_openai | ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
|
||||
| openai_compatible_providers| ✓ | ✓ | ✓ | | ✓ | ✓ | | | | | |
|
||||
| anthropic | ✓ | ✓ | ✓ | ✓ | | ✓ | | | ✓ | ✓ | |
|
||||
| replicate | ✓ | ✓ | ✓ | ✓ | | ✓ | | ✓ | ✓ | | |
|
||||
| bedrock | ✓ | ✓ | ✓ | ✓ | | ✓ | | ✓ | ✓ | ✓ | |
|
||||
| sagemaker | | ✓ | ✓ | | | | | | | | |
|
||||
| vertex_ai | ✓ | | ✓ | | | | ✓ | | | | ✓ |
|
||||
| palm | ✓ | ✓ | | | | | ✓ | | | | |
|
||||
| gemini | ✓ | ✓ | | | | | ✓ | | | | |
|
||||
| cloudflare | | | ✓ | | | ✓ | | | | | |
|
||||
| cohere | | ✓ | ✓ | | | ✓ | | | ✓ | | |
|
||||
| cohere_chat | | ✓ | ✓ | | | ✓ | | | ✓ | | |
|
||||
| huggingface | ✓ | ✓ | ✓ | | | ✓ | | ✓ | ✓ | | |
|
||||
| ai21 | ✓ | ✓ | ✓ | ✓ | | ✓ | | ✓ | | | |
|
||||
| nlp_cloud | ✓ | ✓ | ✓ | | | ✓ | ✓ | ✓ | ✓ | | |
|
||||
| together_ai | ✓ | ✓ | ✓ | | | ✓ | | | | | |
|
||||
| aleph_alpha | | | ✓ | | | ✓ | | | | | |
|
||||
| ollama | ✓ | | ✓ | | | | | | ✓ | | |
|
||||
| ollama_chat | ✓ | | ✓ | | | | | | ✓ | | |
|
||||
| vllm | | | | | | ✓ | ✓ | | | | |
|
||||
| azure | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | | ✓ | | |
|
||||
|
||||
- "✓" indicates that the specified `custom_llm_provider` can raise the corresponding exception.
|
||||
- Empty cells indicate the lack of association or that the provider does not raise that particular exception type as indicated by the function.
|
||||
|
||||
|
||||
> For a deeper understanding of these exceptions, you can check out [this](https://github.com/BerriAI/litellm/blob/d7e58d13bf9ba9edbab2ab2f096f3de7547f35fa/litellm/utils.py#L1544) implementation for additional insights.
|
||||
|
|
|
@ -46,4 +46,13 @@ Pricing is based on usage. We can figure out a price that works for your team, o
|
|||
|
||||
<Image img={require('../img/litellm_hosted_ui_router.png')} />
|
||||
|
||||
#### [**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
||||
#### [**🚨 Schedule Call**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
||||
|
||||
## Feature List
|
||||
|
||||
- Easy way to add/remove models
|
||||
- 100% uptime even when models are added/removed
|
||||
- custom callback webhooks
|
||||
- your domain name with HTTPS
|
||||
- Ability to create/delete User API keys
|
||||
- Reasonable set monthly cost
|
|
@ -14,14 +14,14 @@ import TabItem from '@theme/TabItem';
|
|||
|
||||
```python
|
||||
import os
|
||||
from langchain.chat_models import ChatLiteLLM
|
||||
from langchain.prompts.chat import (
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
from langchain_core.prompts import (
|
||||
ChatPromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
AIMessagePromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
os.environ['OPENAI_API_KEY'] = ""
|
||||
chat = ChatLiteLLM(model="gpt-3.5-turbo")
|
||||
|
@ -30,7 +30,7 @@ messages = [
|
|||
content="what model are you"
|
||||
)
|
||||
]
|
||||
chat(messages)
|
||||
chat.invoke(messages)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
@ -39,14 +39,14 @@ chat(messages)
|
|||
|
||||
```python
|
||||
import os
|
||||
from langchain.chat_models import ChatLiteLLM
|
||||
from langchain.prompts.chat import (
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
from langchain_core.prompts import (
|
||||
ChatPromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
AIMessagePromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
os.environ['ANTHROPIC_API_KEY'] = ""
|
||||
chat = ChatLiteLLM(model="claude-2", temperature=0.3)
|
||||
|
@ -55,7 +55,7 @@ messages = [
|
|||
content="what model are you"
|
||||
)
|
||||
]
|
||||
chat(messages)
|
||||
chat.invoke(messages)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
@ -64,14 +64,14 @@ chat(messages)
|
|||
|
||||
```python
|
||||
import os
|
||||
from langchain.chat_models import ChatLiteLLM
|
||||
from langchain.prompts.chat import (
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
from langchain_core.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
AIMessagePromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
os.environ['REPLICATE_API_TOKEN'] = ""
|
||||
chat = ChatLiteLLM(model="replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1")
|
||||
|
@ -80,7 +80,7 @@ messages = [
|
|||
content="what model are you?"
|
||||
)
|
||||
]
|
||||
chat(messages)
|
||||
chat.invoke(messages)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
@ -89,14 +89,14 @@ chat(messages)
|
|||
|
||||
```python
|
||||
import os
|
||||
from langchain.chat_models import ChatLiteLLM
|
||||
from langchain.prompts.chat import (
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
from langchain_core.prompts import (
|
||||
ChatPromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
AIMessagePromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
os.environ['COHERE_API_KEY'] = ""
|
||||
chat = ChatLiteLLM(model="command-nightly")
|
||||
|
@ -105,32 +105,9 @@ messages = [
|
|||
content="what model are you?"
|
||||
)
|
||||
]
|
||||
chat(messages)
|
||||
chat.invoke(messages)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="palm" label="PaLM - Google">
|
||||
|
||||
```python
|
||||
import os
|
||||
from langchain.chat_models import ChatLiteLLM
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
AIMessagePromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
os.environ['PALM_API_KEY'] = ""
|
||||
chat = ChatLiteLLM(model="palm/chat-bison")
|
||||
messages = [
|
||||
HumanMessage(
|
||||
content="what model are you?"
|
||||
)
|
||||
]
|
||||
chat(messages)
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
|
|
@ -94,9 +94,10 @@ print(response)
|
|||
|
||||
```
|
||||
|
||||
### Set Custom Trace ID, Trace User ID and Tags
|
||||
### Set Custom Trace ID, Trace User ID, Trace Metadata, Trace Version, Trace Release and Tags
|
||||
|
||||
Pass `trace_id`, `trace_user_id`, `trace_metadata`, `trace_version`, `trace_release`, `tags` in `metadata`
|
||||
|
||||
Pass `trace_id`, `trace_user_id` in `metadata`
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
@ -121,12 +122,20 @@ response = completion(
|
|||
metadata={
|
||||
"generation_name": "ishaan-test-generation", # set langfuse Generation Name
|
||||
"generation_id": "gen-id22", # set langfuse Generation ID
|
||||
"version": "test-generation-version" # set langfuse Generation Version
|
||||
"trace_user_id": "user-id2", # set langfuse Trace User ID
|
||||
"session_id": "session-1", # set langfuse Session ID
|
||||
"tags": ["tag1", "tag2"] # set langfuse Tags
|
||||
"tags": ["tag1", "tag2"], # set langfuse Tags
|
||||
"trace_id": "trace-id22", # set langfuse Trace ID
|
||||
"trace_metadata": {"key": "value"}, # set langfuse Trace Metadata
|
||||
"trace_version": "test-trace-version", # set langfuse Trace Version (if not set, defaults to Generation Version)
|
||||
"trace_release": "test-trace-release", # set langfuse Trace Release
|
||||
### OR ###
|
||||
"existing_trace_id": "trace-id22", # if generation is continuation of past trace. This prevents default behaviour of setting a trace name
|
||||
"existing_trace_id": "trace-id22", # if generation is continuation of past trace. This prevents default behaviour of setting a trace name
|
||||
### OR enforce that certain fields are trace overwritten in the trace during the continuation ###
|
||||
"existing_trace_id": "trace-id22",
|
||||
"trace_metadata": {"key": "updated_trace_value"}, # The new value to use for the langfuse Trace Metadata
|
||||
"update_trace_keys": ["input", "output", "trace_metadata"], # Updates the trace input & output to be this generations input & output also updates the Trace Metadata to match the passed in value
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -134,6 +143,38 @@ print(response)
|
|||
|
||||
```
|
||||
|
||||
### Trace & Generation Parameters
|
||||
|
||||
#### Trace Specific Parameters
|
||||
|
||||
* `trace_id` - Identifier for the trace, must use `existing_trace_id` instead or in conjunction with `trace_id` if this is an existing trace, auto-generated by default
|
||||
* `trace_name` - Name of the trace, auto-generated by default
|
||||
* `session_id` - Session identifier for the trace, defaults to `None`
|
||||
* `trace_version` - Version for the trace, defaults to value for `version`
|
||||
* `trace_release` - Release for the trace, defaults to `None`
|
||||
* `trace_metadata` - Metadata for the trace, defaults to `None`
|
||||
* `trace_user_id` - User identifier for the trace, defaults to completion argument `user`
|
||||
* `tags` - Tags for the trace, defeaults to `None`
|
||||
|
||||
##### Updatable Parameters on Continuation
|
||||
|
||||
The following parameters can be updated on a continuation of a trace by passing in the following values into the `update_trace_keys` in the metadata of the completion.
|
||||
|
||||
* `input` - Will set the traces input to be the input of this latest generation
|
||||
* `output` - Will set the traces output to be the output of this generation
|
||||
* `trace_version` - Will set the trace version to be the provided value (To use the latest generations version instead, use `version`)
|
||||
* `trace_release` - Will set the trace release to be the provided value
|
||||
* `trace_metadata` - Will set the trace metadata to the provided value
|
||||
* `trace_user_id` - Will set the trace user id to the provided value
|
||||
|
||||
#### Generation Specific Parameters
|
||||
|
||||
* `generation_id` - Identifier for the generation, auto-generated by default
|
||||
* `generation_name` - Identifier for the generation, auto-generated by default
|
||||
* `prompt` - Langfuse prompt object used for the generation, defaults to None
|
||||
|
||||
Any other key value pairs passed into the metadata not listed in the above spec for a `litellm` completion will be added as a metadata key value pair for the generation.
|
||||
|
||||
### Use LangChain ChatLiteLLM + Langfuse
|
||||
Pass `trace_user_id`, `session_id` in model_kwargs
|
||||
```python
|
||||
|
|
|
@ -535,7 +535,8 @@ print(response)
|
|||
|
||||
| Model Name | Function Call |
|
||||
|----------------------|---------------------------------------------|
|
||||
| Titan Embeddings - G1 | `embedding(model="bedrock/amazon.titan-embed-text-v1", input=input)` |
|
||||
| Titan Embeddings V2 | `embedding(model="bedrock/amazon.titan-embed-text-v2:0", input=input)` |
|
||||
| Titan Embeddings - V1 | `embedding(model="bedrock/amazon.titan-embed-text-v1", input=input)` |
|
||||
| Cohere Embeddings - English | `embedding(model="bedrock/cohere.embed-english-v3", input=input)` |
|
||||
| Cohere Embeddings - Multilingual | `embedding(model="bedrock/cohere.embed-multilingual-v3", input=input)` |
|
||||
|
||||
|
|
54
docs/my-website/docs/providers/deepseek.md
Normal file
54
docs/my-website/docs/providers/deepseek.md
Normal file
|
@ -0,0 +1,54 @@
|
|||
# Deepseek
|
||||
https://deepseek.com/
|
||||
|
||||
**We support ALL Deepseek models, just set `deepseek/` as a prefix when sending completion requests**
|
||||
|
||||
## API Key
|
||||
```python
|
||||
# env variable
|
||||
os.environ['DEEPSEEK_API_KEY']
|
||||
```
|
||||
|
||||
## Sample Usage
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
os.environ['DEEPSEEK_API_KEY'] = ""
|
||||
response = completion(
|
||||
model="deepseek/deepseek-chat",
|
||||
messages=[
|
||||
{"role": "user", "content": "hello from litellm"}
|
||||
],
|
||||
)
|
||||
print(response)
|
||||
```
|
||||
|
||||
## Sample Usage - Streaming
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
os.environ['DEEPSEEK_API_KEY'] = ""
|
||||
response = completion(
|
||||
model="deepseek/deepseek-chat",
|
||||
messages=[
|
||||
{"role": "user", "content": "hello from litellm"}
|
||||
],
|
||||
stream=True
|
||||
)
|
||||
|
||||
for chunk in response:
|
||||
print(chunk)
|
||||
```
|
||||
|
||||
|
||||
## Supported Models - ALL Deepseek Models Supported!
|
||||
We support ALL Deepseek models, just set `deepseek/` as a prefix when sending completion requests
|
||||
|
||||
| Model Name | Function Call |
|
||||
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| deepseek-chat | `completion(model="deepseek/deepseek-chat", messages)` |
|
||||
| deepseek-coder | `completion(model="deepseek/deepseek-chat", messages)` |
|
||||
|
||||
|
|
@ -21,6 +21,11 @@ This is done by adding the "huggingface/" prefix to `model`, example `completion
|
|||
<Tabs>
|
||||
<TabItem value="tgi" label="Text-generation-interface (TGI)">
|
||||
|
||||
By default, LiteLLM will assume a huggingface call follows the TGI format.
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
import os
|
||||
from litellm import completion
|
||||
|
@ -40,9 +45,58 @@ response = completion(
|
|||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
1. Add models to your config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: wizard-coder
|
||||
litellm_params:
|
||||
model: huggingface/WizardLM/WizardCoder-Python-34B-V1.0
|
||||
api_key: os.environ/HUGGINGFACE_API_KEY
|
||||
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
|
||||
```
|
||||
|
||||
|
||||
|
||||
2. Start the proxy
|
||||
|
||||
```bash
|
||||
$ litellm --config /path/to/config.yaml --debug
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "wizard-coder",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "I like you!"
|
||||
}
|
||||
],
|
||||
}'
|
||||
```
|
||||
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
</TabItem>
|
||||
<TabItem value="conv" label="Conversational-task (BlenderBot, etc.)">
|
||||
|
||||
Append `conversational` to the model name
|
||||
|
||||
e.g. `huggingface/conversational/<model-name>`
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
import os
|
||||
from litellm import completion
|
||||
|
@ -54,7 +108,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
|
|||
|
||||
# e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
|
||||
response = completion(
|
||||
model="huggingface/facebook/blenderbot-400M-distill",
|
||||
model="huggingface/conversational/facebook/blenderbot-400M-distill",
|
||||
messages=messages,
|
||||
api_base="https://my-endpoint.huggingface.cloud"
|
||||
)
|
||||
|
@ -62,7 +116,123 @@ response = completion(
|
|||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="none" label="Non TGI/Conversational-task LLMs">
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
1. Add models to your config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: blenderbot
|
||||
litellm_params:
|
||||
model: huggingface/conversational/facebook/blenderbot-400M-distill
|
||||
api_key: os.environ/HUGGINGFACE_API_KEY
|
||||
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
|
||||
```
|
||||
|
||||
|
||||
|
||||
2. Start the proxy
|
||||
|
||||
```bash
|
||||
$ litellm --config /path/to/config.yaml --debug
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "blenderbot",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "I like you!"
|
||||
}
|
||||
],
|
||||
}'
|
||||
```
|
||||
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
</TabItem>
|
||||
<TabItem value="classification" label="Text Classification">
|
||||
|
||||
Append `text-classification` to the model name
|
||||
|
||||
e.g. `huggingface/text-classification/<model-name>`
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
import os
|
||||
from litellm import completion
|
||||
|
||||
# [OPTIONAL] set env var
|
||||
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
|
||||
|
||||
messages = [{ "content": "I like you, I love you!","role": "user"}]
|
||||
|
||||
# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
|
||||
response = completion(
|
||||
model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
|
||||
messages=messages,
|
||||
api_base="https://my-endpoint.endpoints.huggingface.cloud",
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
1. Add models to your config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: bert-classifier
|
||||
litellm_params:
|
||||
model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
|
||||
api_key: os.environ/HUGGINGFACE_API_KEY
|
||||
api_base: "https://my-endpoint.endpoints.huggingface.cloud"
|
||||
```
|
||||
|
||||
|
||||
|
||||
2. Start the proxy
|
||||
|
||||
```bash
|
||||
$ litellm --config /path/to/config.yaml --debug
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "bert-classifier",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "I like you!"
|
||||
}
|
||||
],
|
||||
}'
|
||||
```
|
||||
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
</TabItem>
|
||||
<TabItem value="none" label="Text Generation (NOT TGI)">
|
||||
|
||||
Append `text-generation` to the model name
|
||||
|
||||
e.g. `huggingface/text-generation/<model-name>`
|
||||
|
||||
```python
|
||||
import os
|
||||
|
@ -75,7 +245,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
|
|||
|
||||
# e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints
|
||||
response = completion(
|
||||
model="huggingface/roneneldan/TinyStories-3M",
|
||||
model="huggingface/text-generation/roneneldan/TinyStories-3M",
|
||||
messages=messages,
|
||||
api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
|
||||
)
|
||||
|
|
|
@ -44,14 +44,14 @@ for chunk in response:
|
|||
## Supported Models
|
||||
All models listed here https://docs.mistral.ai/platform/endpoints are supported. We actively maintain the list of models, pricing, token window, etc. [here](https://github.com/BerriAI/litellm/blob/c1b25538277206b9f00de5254d80d6a83bb19a29/model_prices_and_context_window.json).
|
||||
|
||||
| Model Name | Function Call |
|
||||
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| mistral-tiny | `completion(model="mistral/mistral-tiny", messages)` |
|
||||
| mistral-small | `completion(model="mistral/mistral-small", messages)` |
|
||||
| mistral-medium | `completion(model="mistral/mistral-medium", messages)` |
|
||||
| mistral-large-latest | `completion(model="mistral/mistral-large-latest", messages)` |
|
||||
| open-mixtral-8x22b | `completion(model="mistral/open-mixtral-8x22b", messages)` |
|
||||
|
||||
| Model Name | Function Call |
|
||||
|----------------|--------------------------------------------------------------|
|
||||
| Mistral Small | `completion(model="mistral/mistral-small-latest", messages)` |
|
||||
| Mistral Medium | `completion(model="mistral/mistral-medium-latest", messages)`|
|
||||
| Mistral Large | `completion(model="mistral/mistral-large-latest", messages)` |
|
||||
| Mistral 7B | `completion(model="mistral/open-mistral-7b", messages)` |
|
||||
| Mixtral 8x7B | `completion(model="mistral/open-mixtral-8x7b", messages)` |
|
||||
| Mixtral 8x22B | `completion(model="mistral/open-mixtral-8x22b", messages)` |
|
||||
|
||||
## Function Calling
|
||||
|
||||
|
@ -116,6 +116,6 @@ All models listed here https://docs.mistral.ai/platform/endpoints are supported
|
|||
|
||||
| Model Name | Function Call |
|
||||
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| mistral-embed | `embedding(model="mistral/mistral-embed", input)` |
|
||||
| Mistral Embeddings | `embedding(model="mistral/mistral-embed", input)` |
|
||||
|
||||
|
||||
|
|
247
docs/my-website/docs/providers/predibase.md
Normal file
247
docs/my-website/docs/providers/predibase.md
Normal file
|
@ -0,0 +1,247 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# 🆕 Predibase
|
||||
|
||||
LiteLLM supports all models on Predibase
|
||||
|
||||
|
||||
## Usage
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
### API KEYS
|
||||
```python
|
||||
import os
|
||||
os.environ["PREDIBASE_API_KEY"] = ""
|
||||
```
|
||||
|
||||
### Example Call
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
## set ENV variables
|
||||
os.environ["PREDIBASE_API_KEY"] = "predibase key"
|
||||
os.environ["PREDIBASE_TENANT_ID"] = "predibase tenant id"
|
||||
|
||||
# predibase llama-3 call
|
||||
response = completion(
|
||||
model="predibase/llama-3-8b-instruct",
|
||||
messages = [{ "content": "Hello, how are you?","role": "user"}]
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
1. Add models to your config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: llama-3
|
||||
litellm_params:
|
||||
model: predibase/llama-3-8b-instruct
|
||||
api_key: os.environ/PREDIBASE_API_KEY
|
||||
tenant_id: os.environ/PREDIBASE_TENANT_ID
|
||||
```
|
||||
|
||||
|
||||
|
||||
2. Start the proxy
|
||||
|
||||
```bash
|
||||
$ litellm --config /path/to/config.yaml --debug
|
||||
```
|
||||
|
||||
3. Send Request to LiteLLM Proxy Server
|
||||
|
||||
<Tabs>
|
||||
|
||||
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||
|
||||
```python
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="sk-1234", # pass litellm proxy key, if you're using virtual keys
|
||||
base_url="http://0.0.0.0:4000" # litellm-proxy-base url
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="llama-3",
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Be a good human!"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What do you know about earth?"
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="curl" label="curl">
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "llama-3",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Be a good human!"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What do you know about earth?"
|
||||
}
|
||||
],
|
||||
}'
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
## Advanced Usage - Prompt Formatting
|
||||
|
||||
LiteLLM has prompt template mappings for all `meta-llama` llama3 instruct models. [**See Code**](https://github.com/BerriAI/litellm/blob/4f46b4c3975cd0f72b8c5acb2cb429d23580c18a/litellm/llms/prompt_templates/factory.py#L1360)
|
||||
|
||||
To apply a custom prompt template:
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
import os
|
||||
os.environ["PREDIBASE_API_KEY"] = ""
|
||||
|
||||
# Create your own custom prompt template
|
||||
litellm.register_prompt_template(
|
||||
model="togethercomputer/LLaMA-2-7B-32K",
|
||||
initial_prompt_value="You are a good assistant" # [OPTIONAL]
|
||||
roles={
|
||||
"system": {
|
||||
"pre_message": "[INST] <<SYS>>\n", # [OPTIONAL]
|
||||
"post_message": "\n<</SYS>>\n [/INST]\n" # [OPTIONAL]
|
||||
},
|
||||
"user": {
|
||||
"pre_message": "[INST] ", # [OPTIONAL]
|
||||
"post_message": " [/INST]" # [OPTIONAL]
|
||||
},
|
||||
"assistant": {
|
||||
"pre_message": "\n" # [OPTIONAL]
|
||||
"post_message": "\n" # [OPTIONAL]
|
||||
}
|
||||
}
|
||||
final_prompt_value="Now answer as best you can:" # [OPTIONAL]
|
||||
)
|
||||
|
||||
def predibase_custom_model():
|
||||
model = "predibase/togethercomputer/LLaMA-2-7B-32K"
|
||||
response = completion(model=model, messages=messages)
|
||||
print(response['choices'][0]['message']['content'])
|
||||
return response
|
||||
|
||||
predibase_custom_model()
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
```yaml
|
||||
# Model-specific parameters
|
||||
model_list:
|
||||
- model_name: mistral-7b # model alias
|
||||
litellm_params: # actual params for litellm.completion()
|
||||
model: "predibase/mistralai/Mistral-7B-Instruct-v0.1"
|
||||
api_key: os.environ/PREDIBASE_API_KEY
|
||||
initial_prompt_value: "\n"
|
||||
roles: {"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}}
|
||||
final_prompt_value: "\n"
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
max_tokens: 4096
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
## Passing additional params - max_tokens, temperature
|
||||
See all litellm.completion supported params [here](https://docs.litellm.ai/docs/completion/input)
|
||||
|
||||
```python
|
||||
# !pip install litellm
|
||||
from litellm import completion
|
||||
import os
|
||||
## set ENV variables
|
||||
os.environ["PREDIBASE_API_KEY"] = "predibase key"
|
||||
|
||||
# predibae llama-3 call
|
||||
response = completion(
|
||||
model="predibase/llama3-8b-instruct",
|
||||
messages = [{ "content": "Hello, how are you?","role": "user"}],
|
||||
max_tokens=20,
|
||||
temperature=0.5
|
||||
)
|
||||
```
|
||||
|
||||
**proxy**
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: llama-3
|
||||
litellm_params:
|
||||
model: predibase/llama-3-8b-instruct
|
||||
api_key: os.environ/PREDIBASE_API_KEY
|
||||
max_tokens: 20
|
||||
temperature: 0.5
|
||||
```
|
||||
|
||||
## Passings Predibase specific params - adapter_id, adapter_source,
|
||||
Send params [not supported by `litellm.completion()`](https://docs.litellm.ai/docs/completion/input) but supported by Predibase by passing them to `litellm.completion`
|
||||
|
||||
Example `adapter_id`, `adapter_source` are Predibase specific param - [See List](https://github.com/BerriAI/litellm/blob/8a35354dd6dbf4c2fcefcd6e877b980fcbd68c58/litellm/llms/predibase.py#L54)
|
||||
|
||||
```python
|
||||
# !pip install litellm
|
||||
from litellm import completion
|
||||
import os
|
||||
## set ENV variables
|
||||
os.environ["PREDIBASE_API_KEY"] = "predibase key"
|
||||
|
||||
# predibase llama3 call
|
||||
response = completion(
|
||||
model="predibase/llama-3-8b-instruct",
|
||||
messages = [{ "content": "Hello, how are you?","role": "user"}],
|
||||
adapter_id="my_repo/3",
|
||||
adapter_soruce="pbase",
|
||||
)
|
||||
```
|
||||
|
||||
**proxy**
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: llama-3
|
||||
litellm_params:
|
||||
model: predibase/llama-3-8b-instruct
|
||||
api_key: os.environ/PREDIBASE_API_KEY
|
||||
adapter_id: my_repo/3
|
||||
adapter_source: pbase
|
||||
```
|
95
docs/my-website/docs/providers/triton-inference-server.md
Normal file
95
docs/my-website/docs/providers/triton-inference-server.md
Normal file
|
@ -0,0 +1,95 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Triton Inference Server
|
||||
|
||||
LiteLLM supports Embedding Models on Triton Inference Servers
|
||||
|
||||
|
||||
## Usage
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
|
||||
### Example Call
|
||||
|
||||
Use the `triton/` prefix to route to triton server
|
||||
```python
|
||||
from litellm import embedding
|
||||
import os
|
||||
|
||||
response = await litellm.aembedding(
|
||||
model="triton/<your-triton-model>",
|
||||
api_base="https://your-triton-api-base/triton/embeddings", # /embeddings endpoint you want litellm to call on your server
|
||||
input=["good morning from litellm"],
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
1. Add models to your config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: my-triton-model
|
||||
litellm_params:
|
||||
model: triton/<your-triton-model>"
|
||||
api_base: https://your-triton-api-base/triton/embeddings
|
||||
```
|
||||
|
||||
|
||||
2. Start the proxy
|
||||
|
||||
```bash
|
||||
$ litellm --config /path/to/config.yaml --detailed_debug
|
||||
```
|
||||
|
||||
3. Send Request to LiteLLM Proxy Server
|
||||
|
||||
<Tabs>
|
||||
|
||||
<TabItem value="openai" label="OpenAI Python v1.0.0+">
|
||||
|
||||
```python
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
|
||||
# set base_url to your proxy server
|
||||
# set api_key to send to proxy server
|
||||
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")
|
||||
|
||||
response = client.embeddings.create(
|
||||
input=["hello from litellm"],
|
||||
model="my-triton-model"
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="curl" label="curl">
|
||||
|
||||
`--header` is optional, only required if you're using litellm proxy with Virtual Keys
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/embeddings' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--data ' {
|
||||
"model": "my-triton-model",
|
||||
"input": ["write a litellm poem"]
|
||||
}'
|
||||
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
|
@ -477,6 +477,36 @@ print(response)
|
|||
| code-gecko@latest| `completion('code-gecko@latest', messages)` |
|
||||
|
||||
|
||||
## Embedding Models
|
||||
|
||||
#### Usage - Embedding
|
||||
```python
|
||||
import litellm
|
||||
from litellm import embedding
|
||||
litellm.vertex_project = "hardy-device-38811" # Your Project ID
|
||||
litellm.vertex_location = "us-central1" # proj location
|
||||
|
||||
response = embedding(
|
||||
model="vertex_ai/textembedding-gecko",
|
||||
input=["good morning from litellm"],
|
||||
)
|
||||
print(response)
|
||||
```
|
||||
|
||||
#### Supported Embedding Models
|
||||
All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a0249f630a6792d49dffc2c5d9b7/model_prices_and_context_window.json#L835) are supported
|
||||
|
||||
| Model Name | Function Call |
|
||||
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| textembedding-gecko | `embedding(model="vertex_ai/textembedding-gecko", input)` |
|
||||
| textembedding-gecko-multilingual | `embedding(model="vertex_ai/textembedding-gecko-multilingual", input)` |
|
||||
| textembedding-gecko-multilingual@001 | `embedding(model="vertex_ai/textembedding-gecko-multilingual@001", input)` |
|
||||
| textembedding-gecko@001 | `embedding(model="vertex_ai/textembedding-gecko@001", input)` |
|
||||
| textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` |
|
||||
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
|
||||
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
|
||||
|
||||
|
||||
## Extra
|
||||
|
||||
### Using `GOOGLE_APPLICATION_CREDENTIALS`
|
||||
|
@ -520,6 +550,12 @@ def load_vertex_ai_credentials():
|
|||
|
||||
### Using GCP Service Account
|
||||
|
||||
:::info
|
||||
|
||||
Trying to deploy LiteLLM on Google Cloud Run? Tutorial [here](https://docs.litellm.ai/docs/proxy/deploy#deploy-on-google-cloud-run)
|
||||
|
||||
:::
|
||||
|
||||
1. Figure out the Service Account bound to the Google Cloud Run service
|
||||
|
||||
<Image img={require('../../img/gcp_acc_1.png')} />
|
||||
|
|
83
docs/my-website/docs/proxy/customer_routing.md
Normal file
83
docs/my-website/docs/proxy/customer_routing.md
Normal file
|
@ -0,0 +1,83 @@
|
|||
# Region-based Routing
|
||||
|
||||
Route specific customers to eu-only models.
|
||||
|
||||
By specifying 'allowed_model_region' for a customer, LiteLLM will filter-out any models in a model group which is not in the allowed region (i.e. 'eu').
|
||||
|
||||
[**See Code**](https://github.com/BerriAI/litellm/blob/5eb12e30cc5faa73799ebc7e48fc86ebf449c879/litellm/router.py#L2938)
|
||||
|
||||
### 1. Create customer with region-specification
|
||||
|
||||
Use the litellm 'end-user' object for this.
|
||||
|
||||
End-users can be tracked / id'ed by passing the 'user' param to litellm in an openai chat completion/embedding call.
|
||||
|
||||
```bash
|
||||
curl -X POST --location 'http://0.0.0.0:4000/end_user/new' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"user_id" : "ishaan-jaff-45",
|
||||
"allowed_model_region": "eu", # 👈 SPECIFY ALLOWED REGION='eu'
|
||||
}'
|
||||
```
|
||||
|
||||
### 2. Add eu models to model-group
|
||||
|
||||
Add eu models to a model group. For azure models, litellm can automatically infer the region (no need to set it).
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/gpt-35-turbo-eu # 👈 EU azure model
|
||||
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
|
||||
api_key: os.environ/AZURE_EUROPE_API_KEY
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
||||
api_version: "2023-05-15"
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
|
||||
router_settings:
|
||||
enable_pre_call_checks: true # 👈 IMPORTANT
|
||||
```
|
||||
|
||||
Start the proxy
|
||||
|
||||
```yaml
|
||||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
### 3. Test it!
|
||||
|
||||
Make a simple chat completions call to the proxy. In the response headers, you should see the returned api base.
|
||||
|
||||
```bash
|
||||
curl -X POST --location 'http://localhost:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--data '{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what is the meaning of the universe? 1234"
|
||||
}],
|
||||
"user": "ishaan-jaff-45" # 👈 USER ID
|
||||
}
|
||||
'
|
||||
```
|
||||
|
||||
Expected API Base in response headers
|
||||
|
||||
```
|
||||
x-litellm-api-base: "https://my-endpoint-europe-berri-992.openai.azure.com/"
|
||||
```
|
||||
|
||||
### FAQ
|
||||
|
||||
**What happens if there are no available models for that region?**
|
||||
|
||||
Since the router filters out models not in the specified region, it will return back as an error to the user, if no models in that region are available.
|
|
@ -915,39 +915,72 @@ Test Request
|
|||
litellm --test
|
||||
```
|
||||
|
||||
## Logging Proxy Input/Output Traceloop (OpenTelemetry)
|
||||
## Logging Proxy Input/Output in OpenTelemetry format using Traceloop's OpenLLMetry
|
||||
|
||||
Traceloop allows you to log LLM Input/Output in the OpenTelemetry format
|
||||
[OpenLLMetry](https://github.com/traceloop/openllmetry) _(built and maintained by Traceloop)_ is a set of extensions
|
||||
built on top of [OpenTelemetry](https://opentelemetry.io/) that gives you complete observability over your LLM
|
||||
application. Because it uses OpenTelemetry under the
|
||||
hood, [it can be connected to various observability solutions](https://www.traceloop.com/docs/openllmetry/integrations/introduction)
|
||||
like:
|
||||
|
||||
We will use the `--config` to set `litellm.success_callback = ["traceloop"]` this will log all successfull LLM calls to traceloop
|
||||
* [Traceloop](https://www.traceloop.com/docs/openllmetry/integrations/traceloop)
|
||||
* [Axiom](https://www.traceloop.com/docs/openllmetry/integrations/axiom)
|
||||
* [Azure Application Insights](https://www.traceloop.com/docs/openllmetry/integrations/azure)
|
||||
* [Datadog](https://www.traceloop.com/docs/openllmetry/integrations/datadog)
|
||||
* [Dynatrace](https://www.traceloop.com/docs/openllmetry/integrations/dynatrace)
|
||||
* [Grafana Tempo](https://www.traceloop.com/docs/openllmetry/integrations/grafana)
|
||||
* [Honeycomb](https://www.traceloop.com/docs/openllmetry/integrations/honeycomb)
|
||||
* [HyperDX](https://www.traceloop.com/docs/openllmetry/integrations/hyperdx)
|
||||
* [Instana](https://www.traceloop.com/docs/openllmetry/integrations/instana)
|
||||
* [New Relic](https://www.traceloop.com/docs/openllmetry/integrations/newrelic)
|
||||
* [OpenTelemetry Collector](https://www.traceloop.com/docs/openllmetry/integrations/otel-collector)
|
||||
* [Service Now Cloud Observability](https://www.traceloop.com/docs/openllmetry/integrations/service-now)
|
||||
* [Sentry](https://www.traceloop.com/docs/openllmetry/integrations/sentry)
|
||||
* [SigNoz](https://www.traceloop.com/docs/openllmetry/integrations/signoz)
|
||||
* [Splunk](https://www.traceloop.com/docs/openllmetry/integrations/splunk)
|
||||
|
||||
**Step 1** Install traceloop-sdk and set Traceloop API key
|
||||
We will use the `--config` to set `litellm.success_callback = ["traceloop"]` to achieve this, steps are listed below.
|
||||
|
||||
**Step 1:** Install the SDK
|
||||
|
||||
```shell
|
||||
pip install traceloop-sdk -U
|
||||
pip install traceloop-sdk
|
||||
```
|
||||
|
||||
Traceloop outputs standard OpenTelemetry data that can be connected to your observability stack. Send standard OpenTelemetry from LiteLLM Proxy to [Traceloop](https://www.traceloop.com/docs/openllmetry/integrations/traceloop), [Dynatrace](https://www.traceloop.com/docs/openllmetry/integrations/dynatrace), [Datadog](https://www.traceloop.com/docs/openllmetry/integrations/datadog)
|
||||
, [New Relic](https://www.traceloop.com/docs/openllmetry/integrations/newrelic), [Honeycomb](https://www.traceloop.com/docs/openllmetry/integrations/honeycomb), [Grafana Tempo](https://www.traceloop.com/docs/openllmetry/integrations/grafana), [Splunk](https://www.traceloop.com/docs/openllmetry/integrations/splunk), [OpenTelemetry Collector](https://www.traceloop.com/docs/openllmetry/integrations/otel-collector)
|
||||
**Step 2:** Configure Environment Variable for trace exporting
|
||||
|
||||
You will need to configure where to export your traces. Environment variables will control this, example: For Traceloop
|
||||
you should use `TRACELOOP_API_KEY`, whereas for Datadog you use `TRACELOOP_BASE_URL`. For more
|
||||
visit [the Integrations Catalog](https://www.traceloop.com/docs/openllmetry/integrations/introduction).
|
||||
|
||||
If you are using Datadog as the observability solutions then you can set `TRACELOOP_BASE_URL` as:
|
||||
|
||||
```shell
|
||||
TRACELOOP_BASE_URL=http://<datadog-agent-hostname>:4318
|
||||
```
|
||||
|
||||
**Step 3**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
|
||||
|
||||
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
api_key: my-fake-key # replace api_key with actual key
|
||||
litellm_settings:
|
||||
success_callback: ["traceloop"]
|
||||
success_callback: [ "traceloop" ]
|
||||
```
|
||||
|
||||
**Step 3**: Start the proxy, make a test request
|
||||
**Step 4**: Start the proxy, make a test request
|
||||
|
||||
Start proxy
|
||||
|
||||
```shell
|
||||
litellm --config config.yaml --debug
|
||||
```
|
||||
|
||||
Test Request
|
||||
|
||||
```
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
|
|
|
@ -3,34 +3,38 @@ import TabItem from '@theme/TabItem';
|
|||
|
||||
# ⚡ Best Practices for Production
|
||||
|
||||
Expected Performance in Production
|
||||
## 1. Use this config.yaml
|
||||
Use this config.yaml in production (with your own LLMs)
|
||||
|
||||
1 LiteLLM Uvicorn Worker on Kubernetes
|
||||
|
||||
| Description | Value |
|
||||
|--------------|-------|
|
||||
| Avg latency | `50ms` |
|
||||
| Median latency | `51ms` |
|
||||
| `/chat/completions` Requests/second | `35` |
|
||||
| `/chat/completions` Requests/minute | `2100` |
|
||||
| `/chat/completions` Requests/hour | `126K` |
|
||||
|
||||
|
||||
## 1. Switch off Debug Logging
|
||||
|
||||
Remove `set_verbose: True` from your config.yaml
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: fake-openai-endpoint
|
||||
litellm_params:
|
||||
model: openai/fake
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
|
||||
general_settings:
|
||||
master_key: sk-1234 # enter your own master key, ensure it starts with 'sk-'
|
||||
alerting: ["slack"] # Setup slack alerting - get alerts on LLM exceptions, Budget Alerts, Slow LLM Responses
|
||||
proxy_batch_write_at: 60 # Batch write spend updates every 60s
|
||||
|
||||
litellm_settings:
|
||||
set_verbose: True
|
||||
set_verbose: False # Switch off Debug Logging, ensure your logs do not have any debugging on
|
||||
```
|
||||
|
||||
You should only see the following level of details in logs on the proxy server
|
||||
Set slack webhook url in your env
|
||||
```shell
|
||||
# INFO: 192.168.2.205:11774 - "POST /chat/completions HTTP/1.1" 200 OK
|
||||
# INFO: 192.168.2.205:34717 - "POST /chat/completions HTTP/1.1" 200 OK
|
||||
# INFO: 192.168.2.205:29734 - "POST /chat/completions HTTP/1.1" 200 OK
|
||||
export SLACK_WEBHOOK_URL="https://hooks.slack.com/services/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"
|
||||
```
|
||||
|
||||
:::info
|
||||
|
||||
Need Help or want dedicated support ? Talk to a founder [here]: (https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
||||
|
||||
:::
|
||||
|
||||
|
||||
## 2. On Kubernetes - Use 1 Uvicorn worker [Suggested CMD]
|
||||
|
||||
Use this Docker `CMD`. This will start the proxy with 1 Uvicorn Async Worker
|
||||
|
@ -40,21 +44,12 @@ Use this Docker `CMD`. This will start the proxy with 1 Uvicorn Async Worker
|
|||
CMD ["--port", "4000", "--config", "./proxy_server_config.yaml"]
|
||||
```
|
||||
|
||||
## 3. Batch write spend updates every 60s
|
||||
|
||||
The default proxy batch write is 10s. This is to make it easy to see spend when debugging locally.
|
||||
## 3. Use Redis 'port','host', 'password'. NOT 'redis_url'
|
||||
|
||||
In production, we recommend using a longer interval period of 60s. This reduces the number of connections used to make DB writes.
|
||||
If you decide to use Redis, DO NOT use 'redis_url'. We recommend usig redis port, host, and password params.
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
proxy_batch_write_at: 60 # 👈 Frequency of batch writing logs to server (in seconds)
|
||||
```
|
||||
|
||||
## 4. use Redis 'port','host', 'password'. NOT 'redis_url'
|
||||
|
||||
When connecting to Redis use redis port, host, and password params. Not 'redis_url'. We've seen a 80 RPS difference between these 2 approaches when using the async redis client.
|
||||
`redis_url`is 80 RPS slower
|
||||
|
||||
This is still something we're investigating. Keep track of it [here](https://github.com/BerriAI/litellm/issues/3188)
|
||||
|
||||
|
@ -69,103 +64,31 @@ router_settings:
|
|||
redis_password: os.environ/REDIS_PASSWORD
|
||||
```
|
||||
|
||||
## 5. Switch off resetting budgets
|
||||
## Extras
|
||||
### Expected Performance in Production
|
||||
|
||||
Add this to your config.yaml. (Only spend per Key, User and Team will be tracked - spend per API Call will not be written to the LiteLLM Database)
|
||||
```yaml
|
||||
general_settings:
|
||||
disable_reset_budget: true
|
||||
```
|
||||
1 LiteLLM Uvicorn Worker on Kubernetes
|
||||
|
||||
## 6. Move spend logs to separate server (BETA)
|
||||
|
||||
Writing each spend log to the db can slow down your proxy. In testing we saw a 70% improvement in median response time, by moving writing spend logs to a separate server.
|
||||
|
||||
👉 [LiteLLM Spend Logs Server](https://github.com/BerriAI/litellm/tree/main/litellm-js/spend-logs)
|
||||
| Description | Value |
|
||||
|--------------|-------|
|
||||
| Avg latency | `50ms` |
|
||||
| Median latency | `51ms` |
|
||||
| `/chat/completions` Requests/second | `35` |
|
||||
| `/chat/completions` Requests/minute | `2100` |
|
||||
| `/chat/completions` Requests/hour | `126K` |
|
||||
|
||||
|
||||
**Spend Logs**
|
||||
This is a log of the key, tokens, model, and latency for each call on the proxy.
|
||||
### Verifying Debugging logs are off
|
||||
|
||||
[**Full Payload**](https://github.com/BerriAI/litellm/blob/8c9623a6bc4ad9da0a2dac64249a60ed8da719e8/litellm/proxy/utils.py#L1769)
|
||||
|
||||
|
||||
**1. Start the spend logs server**
|
||||
|
||||
```bash
|
||||
docker run -p 3000:3000 \
|
||||
-e DATABASE_URL="postgres://.." \
|
||||
ghcr.io/berriai/litellm-spend_logs:main-latest
|
||||
|
||||
# RUNNING on http://0.0.0.0:3000
|
||||
```
|
||||
|
||||
**2. Connect to proxy**
|
||||
|
||||
|
||||
Example litellm_config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: fake-openai-endpoint
|
||||
litellm_params:
|
||||
model: openai/my-fake-model
|
||||
api_key: my-fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds)
|
||||
```
|
||||
|
||||
Add `SPEND_LOGS_URL` as an environment variable when starting the proxy
|
||||
|
||||
```bash
|
||||
docker run \
|
||||
-v $(pwd)/litellm_config.yaml:/app/config.yaml \
|
||||
-e DATABASE_URL="postgresql://.." \
|
||||
-e SPEND_LOGS_URL="http://host.docker.internal:3000" \ # 👈 KEY CHANGE
|
||||
-p 4000:4000 \
|
||||
ghcr.io/berriai/litellm:main-latest \
|
||||
--config /app/config.yaml --detailed_debug
|
||||
|
||||
# Running on http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
**3. Test Proxy!**
|
||||
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/v1/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--data '{
|
||||
"model": "fake-openai-endpoint",
|
||||
"messages": [
|
||||
{"role": "system", "content": "Be helpful"},
|
||||
{"role": "user", "content": "What do you know?"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
In your LiteLLM Spend Logs Server, you should see
|
||||
|
||||
**Expected Response**
|
||||
|
||||
```
|
||||
Received and stored 1 logs. Total logs in memory: 1
|
||||
...
|
||||
Flushed 1 log to the DB.
|
||||
You should only see the following level of details in logs on the proxy server
|
||||
```shell
|
||||
# INFO: 192.168.2.205:11774 - "POST /chat/completions HTTP/1.1" 200 OK
|
||||
# INFO: 192.168.2.205:34717 - "POST /chat/completions HTTP/1.1" 200 OK
|
||||
# INFO: 192.168.2.205:29734 - "POST /chat/completions HTTP/1.1" 200 OK
|
||||
```
|
||||
|
||||
|
||||
### Machine Specification
|
||||
|
||||
A t2.micro should be sufficient to handle 1k logs / minute on this server.
|
||||
|
||||
This consumes at max 120MB, and <0.1 vCPU.
|
||||
|
||||
## Machine Specifications to Deploy LiteLLM
|
||||
### Machine Specifications to Deploy LiteLLM
|
||||
|
||||
| Service | Spec | CPUs | Memory | Architecture | Version|
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
|
@ -173,7 +96,7 @@ This consumes at max 120MB, and <0.1 vCPU.
|
|||
| Redis Cache | - | - | - | - | 7.0+ Redis Engine|
|
||||
|
||||
|
||||
## Reference Kubernetes Deployment YAML
|
||||
### Reference Kubernetes Deployment YAML
|
||||
|
||||
Reference Kubernetes `deployment.yaml` that was load tested by us
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ This is a new feature, and subject to changes based on feedback.
|
|||
### Step 1. Setup Proxy
|
||||
|
||||
- `JWT_PUBLIC_KEY_URL`: This is the public keys endpoint of your OpenID provider. Typically it's `{openid-provider-base-url}/.well-known/openid-configuration/jwks`. For Keycloak it's `{keycloak_base_url}/realms/{your-realm}/protocol/openid-connect/certs`.
|
||||
- `JWT_AUDIENCE`: This is the audience used for decoding the JWT. If not set, the decode step will not verify the audience.
|
||||
|
||||
```bash
|
||||
export JWT_PUBLIC_KEY_URL="" # "https://demo.duendesoftware.com/.well-known/openid-configuration/jwks"
|
||||
|
|
|
@ -12,8 +12,8 @@ Requirements:
|
|||
|
||||
You can set budgets at 3 levels:
|
||||
- For the proxy
|
||||
- For a user
|
||||
- For a 'user' passed to `/chat/completions`, `/embeddings` etc
|
||||
- For an internal user
|
||||
- For an end-user
|
||||
- For a key
|
||||
- For a key (model specific budgets)
|
||||
|
||||
|
@ -58,7 +58,7 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
|
|||
}'
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="per-user" label="For User">
|
||||
<TabItem value="per-user" label="For Internal User">
|
||||
|
||||
Apply a budget across multiple keys.
|
||||
|
||||
|
@ -165,12 +165,12 @@ curl --location 'http://localhost:4000/team/new' \
|
|||
}
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="per-user-chat" label="For 'user' passed to /chat/completions">
|
||||
<TabItem value="per-user-chat" label="For End User">
|
||||
|
||||
Use this to budget `user` passed to `/chat/completions`, **without needing to create a key for every user**
|
||||
|
||||
**Step 1. Modify config.yaml**
|
||||
Define `litellm.max_user_budget`
|
||||
Define `litellm.max_end_user_budget`
|
||||
```yaml
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
|
@ -328,7 +328,7 @@ You can set:
|
|||
- max parallel requests
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="per-user" label="Per User">
|
||||
<TabItem value="per-user" label="Per Internal User">
|
||||
|
||||
Use `/user/new`, to persist rate limits across multiple keys.
|
||||
|
||||
|
@ -408,7 +408,7 @@ curl --location 'http://localhost:4000/user/new' \
|
|||
```
|
||||
|
||||
|
||||
## Create new keys for existing user
|
||||
## Create new keys for existing internal user
|
||||
|
||||
Just include user_id in the `/key/generate` request.
|
||||
|
||||
|
|
|
@ -96,7 +96,7 @@ print(response)
|
|||
- `router.aimage_generation()` - async image generation calls
|
||||
|
||||
## Advanced - Routing Strategies
|
||||
#### Routing Strategies - Weighted Pick, Rate Limit Aware, Least Busy, Latency Based
|
||||
#### Routing Strategies - Weighted Pick, Rate Limit Aware, Least Busy, Latency Based, Cost Based
|
||||
|
||||
Router provides 4 strategies for routing your calls across multiple deployments:
|
||||
|
||||
|
@ -467,6 +467,101 @@ async def router_acompletion():
|
|||
asyncio.run(router_acompletion())
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="lowest-cost" label="Lowest Cost Routing (Async)">
|
||||
|
||||
Picks a deployment based on the lowest cost
|
||||
|
||||
How this works:
|
||||
- Get all healthy deployments
|
||||
- Select all deployments that are under their provided `rpm/tpm` limits
|
||||
- For each deployment check if `litellm_param["model"]` exists in [`litellm_model_cost_map`](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json)
|
||||
- if deployment does not exist in `litellm_model_cost_map` -> use deployment_cost= `$1`
|
||||
- Select deployment with lowest cost
|
||||
|
||||
```python
|
||||
from litellm import Router
|
||||
import asyncio
|
||||
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {"model": "gpt-4"},
|
||||
"model_info": {"id": "openai-gpt-4"},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {"model": "groq/llama3-8b-8192"},
|
||||
"model_info": {"id": "groq-llama"},
|
||||
},
|
||||
]
|
||||
|
||||
# init router
|
||||
router = Router(model_list=model_list, routing_strategy="cost-based-routing")
|
||||
async def router_acompletion():
|
||||
response = await router.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}]
|
||||
)
|
||||
print(response)
|
||||
|
||||
print(response._hidden_params["model_id"]) # expect groq-llama, since groq/llama has lowest cost
|
||||
return response
|
||||
|
||||
asyncio.run(router_acompletion())
|
||||
|
||||
```
|
||||
|
||||
|
||||
#### Using Custom Input/Output pricing
|
||||
|
||||
Set `litellm_params["input_cost_per_token"]` and `litellm_params["output_cost_per_token"]` for using custom pricing when routing
|
||||
|
||||
```python
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"input_cost_per_token": 0.00003,
|
||||
"output_cost_per_token": 0.00003,
|
||||
},
|
||||
"model_info": {"id": "chatgpt-v-experimental"},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-1",
|
||||
"input_cost_per_token": 0.000000001,
|
||||
"output_cost_per_token": 0.00000001,
|
||||
},
|
||||
"model_info": {"id": "chatgpt-v-1"},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-5",
|
||||
"input_cost_per_token": 10,
|
||||
"output_cost_per_token": 12,
|
||||
},
|
||||
"model_info": {"id": "chatgpt-v-5"},
|
||||
},
|
||||
]
|
||||
# init router
|
||||
router = Router(model_list=model_list, routing_strategy="cost-based-routing")
|
||||
async def router_acompletion():
|
||||
response = await router.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}]
|
||||
)
|
||||
print(response)
|
||||
|
||||
print(response._hidden_params["model_id"]) # expect chatgpt-v-1, since chatgpt-v-1 has lowest cost
|
||||
return response
|
||||
|
||||
asyncio.run(router_acompletion())
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
</Tabs>
|
||||
|
@ -616,6 +711,57 @@ response = router.completion(model="gpt-3.5-turbo", messages=messages)
|
|||
print(f"response: {response}")
|
||||
```
|
||||
|
||||
#### Retries based on Error Type
|
||||
|
||||
Use `RetryPolicy` if you want to set a `num_retries` based on the Exception receieved
|
||||
|
||||
Example:
|
||||
- 4 retries for `ContentPolicyViolationError`
|
||||
- 0 retries for `RateLimitErrors`
|
||||
|
||||
Example Usage
|
||||
|
||||
```python
|
||||
from litellm.router import RetryPolicy
|
||||
retry_policy = RetryPolicy(
|
||||
ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors
|
||||
AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
|
||||
BadRequestErrorRetries=1,
|
||||
TimeoutErrorRetries=2,
|
||||
RateLimitErrorRetries=3,
|
||||
)
|
||||
|
||||
router = litellm.Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "bad-model", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": "bad-key",
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
},
|
||||
],
|
||||
retry_policy=retry_policy,
|
||||
)
|
||||
|
||||
response = await router.acompletion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
### Fallbacks
|
||||
|
||||
If a call fails after num_retries, fall back to another model group.
|
||||
|
@ -940,6 +1086,46 @@ async def test_acompletion_caching_on_router_caching_groups():
|
|||
asyncio.run(test_acompletion_caching_on_router_caching_groups())
|
||||
```
|
||||
|
||||
## Alerting 🚨
|
||||
|
||||
Send alerts to slack / your webhook url for the following events
|
||||
- LLM API Exceptions
|
||||
- Slow LLM Responses
|
||||
|
||||
Get a slack webhook url from https://api.slack.com/messaging/webhooks
|
||||
|
||||
#### Usage
|
||||
Initialize an `AlertingConfig` and pass it to `litellm.Router`. The following code will trigger an alert because `api_key=bad-key` which is invalid
|
||||
|
||||
```python
|
||||
from litellm.router import AlertingConfig
|
||||
import litellm
|
||||
import os
|
||||
|
||||
router = litellm.Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": "bad_key",
|
||||
},
|
||||
}
|
||||
],
|
||||
alerting_config= AlertingConfig(
|
||||
alerting_threshold=10, # threshold for slow / hanging llm responses (in seconds). Defaults to 300 seconds
|
||||
webhook_url= os.getenv("SLACK_WEBHOOK_URL") # webhook you want to send alerts to
|
||||
),
|
||||
)
|
||||
try:
|
||||
await router.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||
)
|
||||
except:
|
||||
pass
|
||||
```
|
||||
|
||||
## Track cost for Azure Deployments
|
||||
|
||||
**Problem**: Azure returns `gpt-4` in the response when `azure/gpt-4-1106-preview` is used. This leads to inaccurate cost tracking
|
||||
|
@ -1108,6 +1294,7 @@ def __init__(
|
|||
"least-busy",
|
||||
"usage-based-routing",
|
||||
"latency-based-routing",
|
||||
"cost-based-routing",
|
||||
] = "simple-shuffle",
|
||||
|
||||
## DEBUGGING ##
|
||||
|
|
|
@ -50,6 +50,7 @@ const sidebars = {
|
|||
items: ["proxy/logging", "proxy/streaming_logging"],
|
||||
},
|
||||
"proxy/team_based_routing",
|
||||
"proxy/customer_routing",
|
||||
"proxy/ui",
|
||||
"proxy/cost_tracking",
|
||||
"proxy/token_auth",
|
||||
|
@ -131,9 +132,13 @@ const sidebars = {
|
|||
"providers/cohere",
|
||||
"providers/anyscale",
|
||||
"providers/huggingface",
|
||||
"providers/watsonx",
|
||||
"providers/predibase",
|
||||
"providers/triton-inference-server",
|
||||
"providers/ollama",
|
||||
"providers/perplexity",
|
||||
"providers/groq",
|
||||
"providers/deepseek",
|
||||
"providers/fireworks_ai",
|
||||
"providers/vllm",
|
||||
"providers/xinference",
|
||||
|
@ -149,7 +154,7 @@ const sidebars = {
|
|||
"providers/openrouter",
|
||||
"providers/custom_openai_proxy",
|
||||
"providers/petals",
|
||||
"providers/watsonx",
|
||||
|
||||
],
|
||||
},
|
||||
"proxy/custom_pricing",
|
||||
|
|
|
@ -291,7 +291,7 @@ def _create_clickhouse_aggregate_tables(client=None, table_names=[]):
|
|||
|
||||
|
||||
def _forecast_daily_cost(data: list):
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
if len(data) == 0:
|
||||
|
|
108
index.yaml
Normal file
108
index.yaml
Normal file
|
@ -0,0 +1,108 @@
|
|||
apiVersion: v1
|
||||
entries:
|
||||
litellm-helm:
|
||||
- apiVersion: v2
|
||||
appVersion: v1.35.38
|
||||
created: "2024-05-06T10:22:24.384392-07:00"
|
||||
dependencies:
|
||||
- condition: db.deployStandalone
|
||||
name: postgresql
|
||||
repository: oci://registry-1.docker.io/bitnamicharts
|
||||
version: '>=13.3.0'
|
||||
- condition: redis.enabled
|
||||
name: redis
|
||||
repository: oci://registry-1.docker.io/bitnamicharts
|
||||
version: '>=18.0.0'
|
||||
description: Call all LLM APIs using the OpenAI format
|
||||
digest: 60f0cfe9e7c1087437cb35f6fb7c43c3ab2be557b6d3aec8295381eb0dfa760f
|
||||
name: litellm-helm
|
||||
type: application
|
||||
urls:
|
||||
- litellm-helm-0.2.0.tgz
|
||||
version: 0.2.0
|
||||
postgresql:
|
||||
- annotations:
|
||||
category: Database
|
||||
images: |
|
||||
- name: os-shell
|
||||
image: docker.io/bitnami/os-shell:12-debian-12-r16
|
||||
- name: postgres-exporter
|
||||
image: docker.io/bitnami/postgres-exporter:0.15.0-debian-12-r14
|
||||
- name: postgresql
|
||||
image: docker.io/bitnami/postgresql:16.2.0-debian-12-r6
|
||||
licenses: Apache-2.0
|
||||
apiVersion: v2
|
||||
appVersion: 16.2.0
|
||||
created: "2024-05-06T10:22:24.387717-07:00"
|
||||
dependencies:
|
||||
- name: common
|
||||
repository: oci://registry-1.docker.io/bitnamicharts
|
||||
tags:
|
||||
- bitnami-common
|
||||
version: 2.x.x
|
||||
description: PostgreSQL (Postgres) is an open source object-relational database
|
||||
known for reliability and data integrity. ACID-compliant, it supports foreign
|
||||
keys, joins, views, triggers and stored procedures.
|
||||
digest: 3c8125526b06833df32e2f626db34aeaedb29d38f03d15349db6604027d4a167
|
||||
home: https://bitnami.com
|
||||
icon: https://bitnami.com/assets/stacks/postgresql/img/postgresql-stack-220x234.png
|
||||
keywords:
|
||||
- postgresql
|
||||
- postgres
|
||||
- database
|
||||
- sql
|
||||
- replication
|
||||
- cluster
|
||||
maintainers:
|
||||
- name: VMware, Inc.
|
||||
url: https://github.com/bitnami/charts
|
||||
name: postgresql
|
||||
sources:
|
||||
- https://github.com/bitnami/charts/tree/main/bitnami/postgresql
|
||||
urls:
|
||||
- charts/postgresql-14.3.1.tgz
|
||||
version: 14.3.1
|
||||
redis:
|
||||
- annotations:
|
||||
category: Database
|
||||
images: |
|
||||
- name: kubectl
|
||||
image: docker.io/bitnami/kubectl:1.29.2-debian-12-r3
|
||||
- name: os-shell
|
||||
image: docker.io/bitnami/os-shell:12-debian-12-r16
|
||||
- name: redis
|
||||
image: docker.io/bitnami/redis:7.2.4-debian-12-r9
|
||||
- name: redis-exporter
|
||||
image: docker.io/bitnami/redis-exporter:1.58.0-debian-12-r4
|
||||
- name: redis-sentinel
|
||||
image: docker.io/bitnami/redis-sentinel:7.2.4-debian-12-r7
|
||||
licenses: Apache-2.0
|
||||
apiVersion: v2
|
||||
appVersion: 7.2.4
|
||||
created: "2024-05-06T10:22:24.391903-07:00"
|
||||
dependencies:
|
||||
- name: common
|
||||
repository: oci://registry-1.docker.io/bitnamicharts
|
||||
tags:
|
||||
- bitnami-common
|
||||
version: 2.x.x
|
||||
description: Redis(R) is an open source, advanced key-value store. It is often
|
||||
referred to as a data structure server since keys can contain strings, hashes,
|
||||
lists, sets and sorted sets.
|
||||
digest: b2fa1835f673a18002ca864c54fadac3c33789b26f6c5e58e2851b0b14a8f984
|
||||
home: https://bitnami.com
|
||||
icon: https://bitnami.com/assets/stacks/redis/img/redis-stack-220x234.png
|
||||
keywords:
|
||||
- redis
|
||||
- keyvalue
|
||||
- database
|
||||
maintainers:
|
||||
- name: VMware, Inc.
|
||||
url: https://github.com/bitnami/charts
|
||||
name: redis
|
||||
sources:
|
||||
- https://github.com/bitnami/charts/tree/main/bitnami/redis
|
||||
urls:
|
||||
- charts/redis-18.19.1.tgz
|
||||
version: 18.19.1
|
||||
generated: "2024-05-06T10:22:24.375026-07:00"
|
BIN
litellm-helm-0.2.0.tgz
Normal file
BIN
litellm-helm-0.2.0.tgz
Normal file
Binary file not shown.
|
@ -1,3 +1,7 @@
|
|||
### Hide pydantic namespace conflict warnings globally ###
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
|
||||
### INIT VARIABLES ###
|
||||
import threading, requests, os
|
||||
from typing import Callable, List, Optional, Dict, Union, Any, Literal
|
||||
|
@ -71,9 +75,11 @@ maritalk_key: Optional[str] = None
|
|||
ai21_key: Optional[str] = None
|
||||
ollama_key: Optional[str] = None
|
||||
openrouter_key: Optional[str] = None
|
||||
predibase_key: Optional[str] = None
|
||||
huggingface_key: Optional[str] = None
|
||||
vertex_project: Optional[str] = None
|
||||
vertex_location: Optional[str] = None
|
||||
predibase_tenant_id: Optional[str] = None
|
||||
togetherai_api_key: Optional[str] = None
|
||||
cloudflare_api_key: Optional[str] = None
|
||||
baseten_key: Optional[str] = None
|
||||
|
@ -361,6 +367,7 @@ openai_compatible_endpoints: List = [
|
|||
"api.deepinfra.com/v1/openai",
|
||||
"api.mistral.ai/v1",
|
||||
"api.groq.com/openai/v1",
|
||||
"api.deepseek.com/v1",
|
||||
"api.together.xyz/v1",
|
||||
]
|
||||
|
||||
|
@ -369,6 +376,7 @@ openai_compatible_providers: List = [
|
|||
"anyscale",
|
||||
"mistral",
|
||||
"groq",
|
||||
"deepseek",
|
||||
"deepinfra",
|
||||
"perplexity",
|
||||
"xinference",
|
||||
|
@ -523,12 +531,15 @@ provider_list: List = [
|
|||
"anyscale",
|
||||
"mistral",
|
||||
"groq",
|
||||
"deepseek",
|
||||
"maritalk",
|
||||
"voyage",
|
||||
"cloudflare",
|
||||
"xinference",
|
||||
"fireworks_ai",
|
||||
"watsonx",
|
||||
"triton",
|
||||
"predibase",
|
||||
"custom", # custom apis
|
||||
]
|
||||
|
||||
|
@ -605,7 +616,6 @@ all_embedding_models = (
|
|||
####### IMAGE GENERATION MODELS ###################
|
||||
openai_image_generation_models = ["dall-e-2", "dall-e-3"]
|
||||
|
||||
|
||||
from .timeout import timeout
|
||||
from .utils import (
|
||||
client,
|
||||
|
@ -613,6 +623,8 @@ from .utils import (
|
|||
get_optional_params,
|
||||
modify_integration,
|
||||
token_counter,
|
||||
create_pretrained_tokenizer,
|
||||
create_tokenizer,
|
||||
cost_per_token,
|
||||
completion_cost,
|
||||
supports_function_calling,
|
||||
|
@ -636,9 +648,11 @@ from .utils import (
|
|||
get_secret,
|
||||
get_supported_openai_params,
|
||||
get_api_base,
|
||||
get_first_chars_messages,
|
||||
)
|
||||
from .llms.huggingface_restapi import HuggingfaceConfig
|
||||
from .llms.anthropic import AnthropicConfig
|
||||
from .llms.predibase import PredibaseConfig
|
||||
from .llms.anthropic_text import AnthropicTextConfig
|
||||
from .llms.replicate import ReplicateConfig
|
||||
from .llms.cohere import CohereConfig
|
||||
|
@ -692,3 +706,4 @@ from .exceptions import (
|
|||
from .budget_manager import BudgetManager
|
||||
from .proxy.proxy_cli import run_server
|
||||
from .router import Router
|
||||
from .assistants.main import *
|
||||
|
|
|
@ -10,8 +10,8 @@
|
|||
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
|
||||
import os
|
||||
import inspect
|
||||
import redis, litellm
|
||||
import redis.asyncio as async_redis
|
||||
import redis, litellm # type: ignore
|
||||
import redis.asyncio as async_redis # type: ignore
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
|
|
495
litellm/assistants/main.py
Normal file
495
litellm/assistants/main.py
Normal file
|
@ -0,0 +1,495 @@
|
|||
# What is this?
|
||||
## Main file for assistants API logic
|
||||
from typing import Iterable
|
||||
import os
|
||||
import litellm
|
||||
from openai import OpenAI
|
||||
from litellm import client
|
||||
from litellm.utils import supports_httpx_timeout
|
||||
from ..llms.openai import OpenAIAssistantsAPI
|
||||
from ..types.llms.openai import *
|
||||
from ..types.router import *
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
openai_assistants_api = OpenAIAssistantsAPI()
|
||||
|
||||
### ASSISTANTS ###
|
||||
|
||||
|
||||
def get_assistants(
|
||||
custom_llm_provider: Literal["openai"],
|
||||
client: Optional[OpenAI] = None,
|
||||
**kwargs,
|
||||
) -> SyncCursorPage[Assistant]:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
response: Optional[SyncCursorPage[Assistant]] = None
|
||||
if custom_llm_provider == "openai":
|
||||
api_base = (
|
||||
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
# set API KEY
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
response = openai_assistants_api.get_assistants(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'get_assistants'. Only 'openai' is supported.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
### THREADS ###
|
||||
|
||||
|
||||
def create_thread(
|
||||
custom_llm_provider: Literal["openai"],
|
||||
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
tool_resources: Optional[OpenAICreateThreadParamsToolResources] = None,
|
||||
client: Optional[OpenAI] = None,
|
||||
**kwargs,
|
||||
) -> Thread:
|
||||
"""
|
||||
- get the llm provider
|
||||
- if openai - route it there
|
||||
- pass through relevant params
|
||||
|
||||
```
|
||||
from litellm import create_thread
|
||||
|
||||
create_thread(
|
||||
custom_llm_provider="openai",
|
||||
### OPTIONAL ###
|
||||
messages = {
|
||||
"role": "user",
|
||||
"content": "Hello, what is AI?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "How does AI work? Explain it in simple terms."
|
||||
}]
|
||||
)
|
||||
```
|
||||
"""
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
response: Optional[Thread] = None
|
||||
if custom_llm_provider == "openai":
|
||||
api_base = (
|
||||
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
# set API KEY
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
response = openai_assistants_api.create_thread(
|
||||
messages=messages,
|
||||
metadata=metadata,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
def get_thread(
|
||||
custom_llm_provider: Literal["openai"],
|
||||
thread_id: str,
|
||||
client: Optional[OpenAI] = None,
|
||||
**kwargs,
|
||||
) -> Thread:
|
||||
"""Get the thread object, given a thread_id"""
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
response: Optional[Thread] = None
|
||||
if custom_llm_provider == "openai":
|
||||
api_base = (
|
||||
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
# set API KEY
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
response = openai_assistants_api.get_thread(
|
||||
thread_id=thread_id,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'get_thread'. Only 'openai' is supported.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
### MESSAGES ###
|
||||
|
||||
|
||||
def add_message(
|
||||
custom_llm_provider: Literal["openai"],
|
||||
thread_id: str,
|
||||
role: Literal["user", "assistant"],
|
||||
content: str,
|
||||
attachments: Optional[List[Attachment]] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
client: Optional[OpenAI] = None,
|
||||
**kwargs,
|
||||
) -> OpenAIMessage:
|
||||
### COMMON OBJECTS ###
|
||||
message_data = MessageData(
|
||||
role=role, content=content, attachments=attachments, metadata=metadata
|
||||
)
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
response: Optional[OpenAIMessage] = None
|
||||
if custom_llm_provider == "openai":
|
||||
api_base = (
|
||||
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
# set API KEY
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
response = openai_assistants_api.add_message(
|
||||
thread_id=thread_id,
|
||||
message_data=message_data,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def get_messages(
|
||||
custom_llm_provider: Literal["openai"],
|
||||
thread_id: str,
|
||||
client: Optional[OpenAI] = None,
|
||||
**kwargs,
|
||||
) -> SyncCursorPage[OpenAIMessage]:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
response: Optional[SyncCursorPage[OpenAIMessage]] = None
|
||||
if custom_llm_provider == "openai":
|
||||
api_base = (
|
||||
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
# set API KEY
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
response = openai_assistants_api.get_messages(
|
||||
thread_id=thread_id,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'get_messages'. Only 'openai' is supported.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
### RUNS ###
|
||||
|
||||
|
||||
def run_thread(
|
||||
custom_llm_provider: Literal["openai"],
|
||||
thread_id: str,
|
||||
assistant_id: str,
|
||||
additional_instructions: Optional[str] = None,
|
||||
instructions: Optional[str] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
model: Optional[str] = None,
|
||||
stream: Optional[bool] = None,
|
||||
tools: Optional[Iterable[AssistantToolParam]] = None,
|
||||
client: Optional[OpenAI] = None,
|
||||
**kwargs,
|
||||
) -> Run:
|
||||
"""Run a given thread + assistant."""
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
response: Optional[Run] = None
|
||||
if custom_llm_provider == "openai":
|
||||
api_base = (
|
||||
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
# set API KEY
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
response = openai_assistants_api.run_thread(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
additional_instructions=additional_instructions,
|
||||
instructions=instructions,
|
||||
metadata=metadata,
|
||||
model=model,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'run_thread'. Only 'openai' is supported.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return response
|
|
@ -10,7 +10,7 @@
|
|||
import os, json, time
|
||||
import litellm
|
||||
from litellm.utils import ModelResponse
|
||||
import requests, threading
|
||||
import requests, threading # type: ignore
|
||||
from typing import Optional, Union, Literal
|
||||
|
||||
|
||||
|
|
|
@ -106,7 +106,7 @@ class InMemoryCache(BaseCache):
|
|||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
async def async_increment(self, key, value: int, **kwargs) -> int:
|
||||
async def async_increment(self, key, value: float, **kwargs) -> float:
|
||||
# get the value
|
||||
init_value = await self.async_get_cache(key=key) or 0
|
||||
value = init_value + value
|
||||
|
@ -177,11 +177,18 @@ class RedisCache(BaseCache):
|
|||
try:
|
||||
# asyncio.get_running_loop().create_task(self.ping())
|
||||
result = asyncio.get_running_loop().create_task(self.ping())
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
verbose_logger.error(
|
||||
"Error connecting to Async Redis client", extra={"error": str(e)}
|
||||
)
|
||||
|
||||
### SYNC HEALTH PING ###
|
||||
self.redis_client.ping()
|
||||
try:
|
||||
self.redis_client.ping()
|
||||
except Exception as e:
|
||||
verbose_logger.error(
|
||||
"Error connecting to Sync Redis client", extra={"error": str(e)}
|
||||
)
|
||||
|
||||
def init_async_client(self):
|
||||
from ._redis import get_redis_async_client
|
||||
|
@ -416,12 +423,12 @@ class RedisCache(BaseCache):
|
|||
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
|
||||
await self.flush_cache_buffer() # logging done in here
|
||||
|
||||
async def async_increment(self, key, value: int, **kwargs) -> int:
|
||||
async def async_increment(self, key, value: float, **kwargs) -> float:
|
||||
_redis_client = self.init_async_client()
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with _redis_client as redis_client:
|
||||
result = await redis_client.incr(name=key, amount=value)
|
||||
result = await redis_client.incrbyfloat(name=key, amount=value)
|
||||
## LOGGING ##
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
|
@ -1375,18 +1382,41 @@ class DualCache(BaseCache):
|
|||
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def async_batch_set_cache(
|
||||
self, cache_list: list, local_only: bool = False, **kwargs
|
||||
):
|
||||
"""
|
||||
Batch write values to the cache
|
||||
"""
|
||||
print_verbose(
|
||||
f"async batch set cache: cache keys: {cache_list}; local_only: {local_only}"
|
||||
)
|
||||
try:
|
||||
if self.in_memory_cache is not None:
|
||||
await self.in_memory_cache.async_set_cache_pipeline(
|
||||
cache_list=cache_list, **kwargs
|
||||
)
|
||||
|
||||
if self.redis_cache is not None and local_only == False:
|
||||
await self.redis_cache.async_set_cache_pipeline(
|
||||
cache_list=cache_list, ttl=kwargs.get("ttl", None)
|
||||
)
|
||||
except Exception as e:
|
||||
print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def async_increment_cache(
|
||||
self, key, value: int, local_only: bool = False, **kwargs
|
||||
) -> int:
|
||||
self, key, value: float, local_only: bool = False, **kwargs
|
||||
) -> float:
|
||||
"""
|
||||
Key - the key in cache
|
||||
|
||||
Value - int - the value you want to increment by
|
||||
Value - float - the value you want to increment by
|
||||
|
||||
Returns - int - the incremented value
|
||||
Returns - float - the incremented value
|
||||
"""
|
||||
try:
|
||||
result: int = value
|
||||
result: float = value
|
||||
if self.in_memory_cache is not None:
|
||||
result = await self.in_memory_cache.async_increment(
|
||||
key, value, **kwargs
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
#### What this does ####
|
||||
# On success + failure, log events to aispend.io
|
||||
import dotenv, os
|
||||
import requests
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
|
|
@ -4,18 +4,30 @@ import datetime
|
|||
class AthinaLogger:
|
||||
def __init__(self):
|
||||
import os
|
||||
|
||||
self.athina_api_key = os.getenv("ATHINA_API_KEY")
|
||||
self.headers = {
|
||||
"athina-api-key": self.athina_api_key,
|
||||
"Content-Type": "application/json"
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
self.athina_logging_url = "https://log.athina.ai/api/v1/log/inference"
|
||||
self.additional_keys = ["environment", "prompt_slug", "customer_id", "customer_user_id", "session_id", "external_reference_id", "context", "expected_response", "user_query"]
|
||||
self.additional_keys = [
|
||||
"environment",
|
||||
"prompt_slug",
|
||||
"customer_id",
|
||||
"customer_user_id",
|
||||
"session_id",
|
||||
"external_reference_id",
|
||||
"context",
|
||||
"expected_response",
|
||||
"user_query",
|
||||
]
|
||||
|
||||
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
import json
|
||||
import traceback
|
||||
|
||||
try:
|
||||
response_json = response_obj.model_dump() if response_obj else {}
|
||||
data = {
|
||||
|
@ -23,32 +35,51 @@ class AthinaLogger:
|
|||
"request": kwargs,
|
||||
"response": response_json,
|
||||
"prompt_tokens": response_json.get("usage", {}).get("prompt_tokens"),
|
||||
"completion_tokens": response_json.get("usage", {}).get("completion_tokens"),
|
||||
"completion_tokens": response_json.get("usage", {}).get(
|
||||
"completion_tokens"
|
||||
),
|
||||
"total_tokens": response_json.get("usage", {}).get("total_tokens"),
|
||||
}
|
||||
|
||||
if type(end_time) == datetime.datetime and type(start_time) == datetime.datetime:
|
||||
data["response_time"] = int((end_time - start_time).total_seconds() * 1000)
|
||||
|
||||
if (
|
||||
type(end_time) == datetime.datetime
|
||||
and type(start_time) == datetime.datetime
|
||||
):
|
||||
data["response_time"] = int(
|
||||
(end_time - start_time).total_seconds() * 1000
|
||||
)
|
||||
|
||||
if "messages" in kwargs:
|
||||
data["prompt"] = kwargs.get("messages", None)
|
||||
|
||||
# Directly add tools or functions if present
|
||||
optional_params = kwargs.get("optional_params", {})
|
||||
data.update((k, v) for k, v in optional_params.items() if k in ["tools", "functions"])
|
||||
data.update(
|
||||
(k, v)
|
||||
for k, v in optional_params.items()
|
||||
if k in ["tools", "functions"]
|
||||
)
|
||||
|
||||
# Add additional metadata keys
|
||||
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
||||
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
||||
if metadata:
|
||||
for key in self.additional_keys:
|
||||
if key in metadata:
|
||||
data[key] = metadata[key]
|
||||
|
||||
response = requests.post(self.athina_logging_url, headers=self.headers, data=json.dumps(data, default=str))
|
||||
response = requests.post(
|
||||
self.athina_logging_url,
|
||||
headers=self.headers,
|
||||
data=json.dumps(data, default=str),
|
||||
)
|
||||
if response.status_code != 200:
|
||||
print_verbose(f"Athina Logger Error - {response.text}, {response.status_code}")
|
||||
print_verbose(
|
||||
f"Athina Logger Error - {response.text}, {response.status_code}"
|
||||
)
|
||||
else:
|
||||
print_verbose(f"Athina Logger Succeeded - {response.text}")
|
||||
except Exception as e:
|
||||
print_verbose(f"Athina Logger Error - {e}, Stack trace: {traceback.format_exc()}")
|
||||
pass
|
||||
print_verbose(
|
||||
f"Athina Logger Error - {e}, Stack trace: {traceback.format_exc()}"
|
||||
)
|
||||
pass
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#### What this does ####
|
||||
# On success + failure, log events to aispend.io
|
||||
import dotenv, os
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
#### What this does ####
|
||||
# On success, logs events to Promptlayer
|
||||
import dotenv, os
|
||||
import requests
|
||||
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.caching import DualCache
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
#### What this does ####
|
||||
# On success, logs events to Promptlayer
|
||||
import dotenv, os
|
||||
import requests
|
||||
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.caching import DualCache
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# On success + failure, log events to Supabase
|
||||
|
||||
import dotenv, os
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# On success + failure, log events to Supabase
|
||||
|
||||
import dotenv, os
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
|
|
@ -1,15 +1,17 @@
|
|||
import requests
|
||||
import requests # type: ignore
|
||||
import json
|
||||
import traceback
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
class GreenscaleLogger:
|
||||
def __init__(self):
|
||||
import os
|
||||
|
||||
self.greenscale_api_key = os.getenv("GREENSCALE_API_KEY")
|
||||
self.headers = {
|
||||
"api-key": self.greenscale_api_key,
|
||||
"Content-Type": "application/json"
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
self.greenscale_logging_url = os.getenv("GREENSCALE_ENDPOINT")
|
||||
|
||||
|
@ -19,33 +21,48 @@ class GreenscaleLogger:
|
|||
data = {
|
||||
"modelId": kwargs.get("model"),
|
||||
"inputTokenCount": response_json.get("usage", {}).get("prompt_tokens"),
|
||||
"outputTokenCount": response_json.get("usage", {}).get("completion_tokens"),
|
||||
"outputTokenCount": response_json.get("usage", {}).get(
|
||||
"completion_tokens"
|
||||
),
|
||||
}
|
||||
data["timestamp"] = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
|
||||
if type(end_time) == datetime and type(start_time) == datetime:
|
||||
data["invocationLatency"] = int((end_time - start_time).total_seconds() * 1000)
|
||||
data["timestamp"] = datetime.now(timezone.utc).strftime(
|
||||
"%Y-%m-%dT%H:%M:%SZ"
|
||||
)
|
||||
|
||||
if type(end_time) == datetime and type(start_time) == datetime:
|
||||
data["invocationLatency"] = int(
|
||||
(end_time - start_time).total_seconds() * 1000
|
||||
)
|
||||
|
||||
# Add additional metadata keys to tags
|
||||
tags = []
|
||||
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
||||
for key, value in metadata.items():
|
||||
if key.startswith("greenscale"):
|
||||
if key.startswith("greenscale"):
|
||||
if key == "greenscale_project":
|
||||
data["project"] = value
|
||||
elif key == "greenscale_application":
|
||||
data["application"] = value
|
||||
else:
|
||||
tags.append({"key": key.replace("greenscale_", ""), "value": str(value)})
|
||||
|
||||
tags.append(
|
||||
{"key": key.replace("greenscale_", ""), "value": str(value)}
|
||||
)
|
||||
|
||||
data["tags"] = tags
|
||||
|
||||
response = requests.post(self.greenscale_logging_url, headers=self.headers, data=json.dumps(data, default=str))
|
||||
response = requests.post(
|
||||
self.greenscale_logging_url,
|
||||
headers=self.headers,
|
||||
data=json.dumps(data, default=str),
|
||||
)
|
||||
if response.status_code != 200:
|
||||
print_verbose(f"Greenscale Logger Error - {response.text}, {response.status_code}")
|
||||
print_verbose(
|
||||
f"Greenscale Logger Error - {response.text}, {response.status_code}"
|
||||
)
|
||||
else:
|
||||
print_verbose(f"Greenscale Logger Succeeded - {response.text}")
|
||||
except Exception as e:
|
||||
print_verbose(f"Greenscale Logger Error - {e}, Stack trace: {traceback.format_exc()}")
|
||||
pass
|
||||
print_verbose(
|
||||
f"Greenscale Logger Error - {e}, Stack trace: {traceback.format_exc()}"
|
||||
)
|
||||
pass
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#### What this does ####
|
||||
# On success, logs events to Helicone
|
||||
import dotenv, os
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
import litellm
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
|
|
|
@ -262,6 +262,23 @@ class LangFuseLogger:
|
|||
|
||||
try:
|
||||
tags = []
|
||||
try:
|
||||
metadata = copy.deepcopy(
|
||||
metadata
|
||||
) # Avoid modifying the original metadata
|
||||
except:
|
||||
new_metadata = {}
|
||||
for key, value in metadata.items():
|
||||
if (
|
||||
isinstance(value, list)
|
||||
or isinstance(value, dict)
|
||||
or isinstance(value, str)
|
||||
or isinstance(value, int)
|
||||
or isinstance(value, float)
|
||||
):
|
||||
new_metadata[key] = copy.deepcopy(value)
|
||||
metadata = new_metadata
|
||||
|
||||
supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3")
|
||||
supports_prompt = Version(langfuse.version.__version__) >= Version("2.7.3")
|
||||
supports_costs = Version(langfuse.version.__version__) >= Version("2.7.3")
|
||||
|
@ -272,36 +289,9 @@ class LangFuseLogger:
|
|||
print_verbose(f"Langfuse Layer Logging - logging to langfuse v2 ")
|
||||
|
||||
if supports_tags:
|
||||
metadata_tags = metadata.get("tags", [])
|
||||
metadata_tags = metadata.pop("tags", [])
|
||||
tags = metadata_tags
|
||||
|
||||
trace_name = metadata.get("trace_name", None)
|
||||
trace_id = metadata.get("trace_id", None)
|
||||
existing_trace_id = metadata.get("existing_trace_id", None)
|
||||
if trace_name is None and existing_trace_id is None:
|
||||
# just log `litellm-{call_type}` as the trace name
|
||||
## DO NOT SET TRACE_NAME if trace-id set. this can lead to overwriting of past traces.
|
||||
trace_name = f"litellm-{kwargs.get('call_type', 'completion')}"
|
||||
|
||||
if existing_trace_id is not None:
|
||||
trace_params = {"id": existing_trace_id}
|
||||
else: # don't overwrite an existing trace
|
||||
trace_params = {
|
||||
"name": trace_name,
|
||||
"input": input,
|
||||
"user_id": metadata.get("trace_user_id", user_id),
|
||||
"id": trace_id,
|
||||
"session_id": metadata.get("session_id", None),
|
||||
}
|
||||
|
||||
if level == "ERROR":
|
||||
trace_params["status_message"] = output
|
||||
else:
|
||||
trace_params["output"] = output
|
||||
|
||||
cost = kwargs.get("response_cost", None)
|
||||
print_verbose(f"trace: {cost}")
|
||||
|
||||
# Clean Metadata before logging - never log raw metadata
|
||||
# the raw metadata can contain circular references which leads to infinite recursion
|
||||
# we clean out all extra litellm metadata params before logging
|
||||
|
@ -328,6 +318,67 @@ class LangFuseLogger:
|
|||
else:
|
||||
clean_metadata[key] = value
|
||||
|
||||
session_id = clean_metadata.pop("session_id", None)
|
||||
trace_name = clean_metadata.pop("trace_name", None)
|
||||
trace_id = clean_metadata.pop("trace_id", None)
|
||||
existing_trace_id = clean_metadata.pop("existing_trace_id", None)
|
||||
update_trace_keys = clean_metadata.pop("update_trace_keys", [])
|
||||
|
||||
if trace_name is None and existing_trace_id is None:
|
||||
# just log `litellm-{call_type}` as the trace name
|
||||
## DO NOT SET TRACE_NAME if trace-id set. this can lead to overwriting of past traces.
|
||||
trace_name = f"litellm-{kwargs.get('call_type', 'completion')}"
|
||||
|
||||
if existing_trace_id is not None:
|
||||
trace_params = {"id": existing_trace_id}
|
||||
|
||||
# Update the following keys for this trace
|
||||
for metadata_param_key in update_trace_keys:
|
||||
trace_param_key = metadata_param_key.replace("trace_", "")
|
||||
if trace_param_key not in trace_params:
|
||||
updated_trace_value = clean_metadata.pop(
|
||||
metadata_param_key, None
|
||||
)
|
||||
if updated_trace_value is not None:
|
||||
trace_params[trace_param_key] = updated_trace_value
|
||||
|
||||
# Pop the trace specific keys that would have been popped if there were a new trace
|
||||
for key in list(
|
||||
filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
|
||||
):
|
||||
clean_metadata.pop(key, None)
|
||||
|
||||
# Special keys that are found in the function arguments and not the metadata
|
||||
if "input" in update_trace_keys:
|
||||
trace_params["input"] = input
|
||||
if "output" in update_trace_keys:
|
||||
trace_params["output"] = output
|
||||
else: # don't overwrite an existing trace
|
||||
trace_params = {
|
||||
"id": trace_id,
|
||||
"name": trace_name,
|
||||
"session_id": session_id,
|
||||
"input": input,
|
||||
"version": clean_metadata.pop(
|
||||
"trace_version", clean_metadata.get("version", None)
|
||||
), # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence
|
||||
"user_id": user_id,
|
||||
}
|
||||
for key in list(
|
||||
filter(lambda key: key.startswith("trace_"), clean_metadata.keys())
|
||||
):
|
||||
trace_params[key.replace("trace_", "")] = clean_metadata.pop(
|
||||
key, None
|
||||
)
|
||||
|
||||
if level == "ERROR":
|
||||
trace_params["status_message"] = output
|
||||
else:
|
||||
trace_params["output"] = output
|
||||
|
||||
cost = kwargs.get("response_cost", None)
|
||||
print_verbose(f"trace: {cost}")
|
||||
|
||||
if (
|
||||
litellm._langfuse_default_tags is not None
|
||||
and isinstance(litellm._langfuse_default_tags, list)
|
||||
|
@ -387,7 +438,7 @@ class LangFuseLogger:
|
|||
"completion_tokens": response_obj["usage"]["completion_tokens"],
|
||||
"total_cost": cost if supports_costs else None,
|
||||
}
|
||||
generation_name = metadata.get("generation_name", None)
|
||||
generation_name = clean_metadata.pop("generation_name", None)
|
||||
if generation_name is None:
|
||||
# just log `litellm-{call_type}` as the generation name
|
||||
generation_name = f"litellm-{kwargs.get('call_type', 'completion')}"
|
||||
|
@ -402,7 +453,7 @@ class LangFuseLogger:
|
|||
|
||||
generation_params = {
|
||||
"name": generation_name,
|
||||
"id": metadata.get("generation_id", generation_id),
|
||||
"id": clean_metadata.pop("generation_id", generation_id),
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"model": kwargs["model"],
|
||||
|
@ -412,10 +463,11 @@ class LangFuseLogger:
|
|||
"usage": usage,
|
||||
"metadata": clean_metadata,
|
||||
"level": level,
|
||||
"version": clean_metadata.pop("version", None),
|
||||
}
|
||||
|
||||
if supports_prompt:
|
||||
generation_params["prompt"] = metadata.get("prompt", None)
|
||||
generation_params["prompt"] = clean_metadata.pop("prompt", None)
|
||||
|
||||
if output is not None and isinstance(output, str) and level == "ERROR":
|
||||
generation_params["status_message"] = output
|
||||
|
@ -426,7 +478,7 @@ class LangFuseLogger:
|
|||
)
|
||||
|
||||
generation_client = trace.generation(**generation_params)
|
||||
|
||||
|
||||
return generation_client.trace_id, generation_id
|
||||
except Exception as e:
|
||||
verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")
|
||||
|
|
|
@ -1,15 +1,14 @@
|
|||
#### What this does ####
|
||||
# On success, logs events to Langsmith
|
||||
import dotenv, os
|
||||
import requests
|
||||
import requests
|
||||
import dotenv, os # type: ignore
|
||||
import requests # type: ignore
|
||||
from datetime import datetime
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import asyncio
|
||||
import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel # type: ignore
|
||||
|
||||
|
||||
def is_serializable(value):
|
||||
|
@ -79,8 +78,6 @@ class LangsmithLogger:
|
|||
except:
|
||||
response_obj = response_obj.dict() # type: ignore
|
||||
|
||||
print(f"response_obj: {response_obj}")
|
||||
|
||||
data = {
|
||||
"name": run_name,
|
||||
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
|
||||
|
@ -90,7 +87,6 @@ class LangsmithLogger:
|
|||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
}
|
||||
print(f"data: {data}")
|
||||
|
||||
response = requests.post(
|
||||
"https://api.smith.langchain.com/runs",
|
||||
|
|
|
@ -4,7 +4,6 @@ from datetime import datetime, timezone
|
|||
import traceback
|
||||
import dotenv
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
import packaging
|
||||
|
||||
|
@ -18,13 +17,33 @@ def parse_usage(usage):
|
|||
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
|
||||
}
|
||||
|
||||
def parse_tool_calls(tool_calls):
|
||||
if tool_calls is None:
|
||||
return None
|
||||
|
||||
def clean_tool_call(tool_call):
|
||||
|
||||
serialized = {
|
||||
"type": tool_call.type,
|
||||
"id": tool_call.id,
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments,
|
||||
}
|
||||
}
|
||||
|
||||
return serialized
|
||||
|
||||
return [clean_tool_call(tool_call) for tool_call in tool_calls]
|
||||
|
||||
|
||||
def parse_messages(input):
|
||||
|
||||
if input is None:
|
||||
return None
|
||||
|
||||
def clean_message(message):
|
||||
# if is strin, return as is
|
||||
# if is string, return as is
|
||||
if isinstance(message, str):
|
||||
return message
|
||||
|
||||
|
@ -38,9 +57,7 @@ def parse_messages(input):
|
|||
|
||||
# Only add tool_calls and function_call to res if they are set
|
||||
if message.get("tool_calls"):
|
||||
serialized["tool_calls"] = message.get("tool_calls")
|
||||
if message.get("function_call"):
|
||||
serialized["function_call"] = message.get("function_call")
|
||||
serialized["tool_calls"] = parse_tool_calls(message.get("tool_calls"))
|
||||
|
||||
return serialized
|
||||
|
||||
|
@ -93,8 +110,13 @@ class LunaryLogger:
|
|||
print_verbose(f"Lunary Logging - Logging request for model {model}")
|
||||
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
optional_params = kwargs.get("optional_params", {})
|
||||
metadata = litellm_params.get("metadata", {}) or {}
|
||||
|
||||
if optional_params:
|
||||
# merge into extra
|
||||
extra = {**extra, **optional_params}
|
||||
|
||||
tags = litellm_params.pop("tags", None) or []
|
||||
|
||||
if extra:
|
||||
|
@ -104,7 +126,7 @@ class LunaryLogger:
|
|||
|
||||
# keep only serializable types
|
||||
for param, value in extra.items():
|
||||
if not isinstance(value, (str, int, bool, float)):
|
||||
if not isinstance(value, (str, int, bool, float)) and param != "tools":
|
||||
try:
|
||||
extra[param] = str(value)
|
||||
except:
|
||||
|
@ -140,7 +162,7 @@ class LunaryLogger:
|
|||
metadata=metadata,
|
||||
runtime="litellm",
|
||||
tags=tags,
|
||||
extra=extra,
|
||||
params=extra,
|
||||
)
|
||||
|
||||
self.lunary_client.track_event(
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
## On Success events log cost to OpenMeter - https://github.com/BerriAI/litellm/issues/1268
|
||||
|
||||
import dotenv, os, json
|
||||
import requests
|
||||
import litellm
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
|
@ -38,7 +37,7 @@ class OpenMeterLogger(CustomLogger):
|
|||
in the environment
|
||||
"""
|
||||
missing_keys = []
|
||||
if litellm.get_secret("OPENMETER_API_KEY", None) is None:
|
||||
if os.getenv("OPENMETER_API_KEY", None) is None:
|
||||
missing_keys.append("OPENMETER_API_KEY")
|
||||
|
||||
if len(missing_keys) > 0:
|
||||
|
@ -60,47 +59,56 @@ class OpenMeterLogger(CustomLogger):
|
|||
"total_tokens": response_obj["usage"].get("total_tokens"),
|
||||
}
|
||||
|
||||
subject = (kwargs.get("user", None),) # end-user passed in via 'user' param
|
||||
if not subject:
|
||||
raise Exception("OpenMeter: user is required")
|
||||
|
||||
return {
|
||||
"specversion": "1.0",
|
||||
"type": os.getenv("OPENMETER_EVENT_TYPE", "litellm_tokens"),
|
||||
"id": call_id,
|
||||
"time": dt,
|
||||
"subject": kwargs.get("user", ""), # end-user passed in via 'user' param
|
||||
"subject": subject,
|
||||
"source": "litellm-proxy",
|
||||
"data": {"model": model, "cost": cost, **usage},
|
||||
}
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
_url = litellm.get_secret(
|
||||
"OPENMETER_API_ENDPOINT", default_value="https://openmeter.cloud"
|
||||
)
|
||||
_url = os.getenv("OPENMETER_API_ENDPOINT", "https://openmeter.cloud")
|
||||
if _url.endswith("/"):
|
||||
_url += "api/v1/events"
|
||||
else:
|
||||
_url += "/api/v1/events"
|
||||
|
||||
api_key = litellm.get_secret("OPENMETER_API_KEY")
|
||||
api_key = os.getenv("OPENMETER_API_KEY")
|
||||
|
||||
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
|
||||
self.sync_http_handler.post(
|
||||
url=_url,
|
||||
data=_data,
|
||||
headers={
|
||||
"Content-Type": "application/cloudevents+json",
|
||||
"Authorization": "Bearer {}".format(api_key),
|
||||
},
|
||||
)
|
||||
_headers = {
|
||||
"Content-Type": "application/cloudevents+json",
|
||||
"Authorization": "Bearer {}".format(api_key),
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.sync_http_handler.post(
|
||||
url=_url,
|
||||
data=json.dumps(_data),
|
||||
headers=_headers,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
except Exception as e:
|
||||
if hasattr(response, "text"):
|
||||
litellm.print_verbose(f"\nError Message: {response.text}")
|
||||
raise e
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
_url = litellm.get_secret(
|
||||
"OPENMETER_API_ENDPOINT", default_value="https://openmeter.cloud"
|
||||
)
|
||||
_url = os.getenv("OPENMETER_API_ENDPOINT", "https://openmeter.cloud")
|
||||
if _url.endswith("/"):
|
||||
_url += "api/v1/events"
|
||||
else:
|
||||
_url += "/api/v1/events"
|
||||
|
||||
api_key = litellm.get_secret("OPENMETER_API_KEY")
|
||||
api_key = os.getenv("OPENMETER_API_KEY")
|
||||
|
||||
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
|
||||
_headers = {
|
||||
|
@ -117,7 +125,6 @@ class OpenMeterLogger(CustomLogger):
|
|||
|
||||
response.raise_for_status()
|
||||
except Exception as e:
|
||||
print(f"\nAn Exception Occurred - {str(e)}")
|
||||
if hasattr(response, "text"):
|
||||
print(f"\nError Message: {response.text}")
|
||||
litellm.print_verbose(f"\nError Message: {response.text}")
|
||||
raise e
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
# On success, log events to Prometheus
|
||||
|
||||
import dotenv, os
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
@ -19,7 +19,6 @@ class PrometheusLogger:
|
|||
**kwargs,
|
||||
):
|
||||
try:
|
||||
print(f"in init prometheus metrics")
|
||||
from prometheus_client import Counter
|
||||
|
||||
self.litellm_llm_api_failed_requests_metric = Counter(
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
|
||||
import dotenv, os
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
@ -183,7 +183,6 @@ class PrometheusServicesLogger:
|
|||
)
|
||||
|
||||
async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
|
||||
print(f"received error payload: {payload.error}")
|
||||
if self.mock_testing:
|
||||
self.mock_testing_failure_calls += 1
|
||||
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
#### What this does ####
|
||||
# On success, logs events to Promptlayer
|
||||
import dotenv, os
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
from pydantic import BaseModel
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
class PromptLayerLogger:
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
|
@ -32,7 +33,11 @@ class PromptLayerLogger:
|
|||
tags = kwargs["litellm_params"]["metadata"]["pl_tags"]
|
||||
|
||||
# Remove "pl_tags" from metadata
|
||||
metadata = {k:v for k, v in kwargs["litellm_params"]["metadata"].items() if k != "pl_tags"}
|
||||
metadata = {
|
||||
k: v
|
||||
for k, v in kwargs["litellm_params"]["metadata"].items()
|
||||
if k != "pl_tags"
|
||||
}
|
||||
|
||||
print_verbose(
|
||||
f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}"
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
# On success + failure, log events to Supabase
|
||||
|
||||
import dotenv, os
|
||||
import requests
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
|
|
@ -1,25 +1,82 @@
|
|||
#### What this does ####
|
||||
# Class for sending Slack Alerts #
|
||||
import dotenv, os
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import copy
|
||||
import traceback
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
import litellm
|
||||
import litellm, threading
|
||||
from typing import List, Literal, Any, Union, Optional, Dict
|
||||
from litellm.caching import DualCache
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
import datetime
|
||||
from pydantic import BaseModel
|
||||
from enum import Enum
|
||||
from datetime import datetime as dt, timedelta
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
import random
|
||||
|
||||
|
||||
class SlackAlerting:
|
||||
class LiteLLMBase(BaseModel):
|
||||
"""
|
||||
Implements default functions, all pydantic objects should have.
|
||||
"""
|
||||
|
||||
def json(self, **kwargs):
|
||||
try:
|
||||
return self.model_dump() # noqa
|
||||
except:
|
||||
# if using pydantic v1
|
||||
return self.dict()
|
||||
|
||||
|
||||
class SlackAlertingArgs(LiteLLMBase):
|
||||
daily_report_frequency: int = 12 * 60 * 60 # 12 hours
|
||||
report_check_interval: int = 5 * 60 # 5 minutes
|
||||
|
||||
|
||||
class DeploymentMetrics(LiteLLMBase):
|
||||
"""
|
||||
Metrics per deployment, stored in cache
|
||||
|
||||
Used for daily reporting
|
||||
"""
|
||||
|
||||
id: str
|
||||
"""id of deployment in router model list"""
|
||||
|
||||
failed_request: bool
|
||||
"""did it fail the request?"""
|
||||
|
||||
latency_per_output_token: Optional[float]
|
||||
"""latency/output token of deployment"""
|
||||
|
||||
updated_at: dt
|
||||
"""Current time of deployment being updated"""
|
||||
|
||||
|
||||
class SlackAlertingCacheKeys(Enum):
|
||||
"""
|
||||
Enum for deployment daily metrics keys - {deployment_id}:{enum}
|
||||
"""
|
||||
|
||||
failed_requests_key = "failed_requests_daily_metrics"
|
||||
latency_key = "latency_daily_metrics"
|
||||
report_sent_key = "daily_metrics_report_sent"
|
||||
|
||||
|
||||
class SlackAlerting(CustomLogger):
|
||||
"""
|
||||
Class for sending Slack Alerts
|
||||
"""
|
||||
|
||||
# Class variables or attributes
|
||||
def __init__(
|
||||
self,
|
||||
alerting_threshold: float = 300,
|
||||
internal_usage_cache: Optional[DualCache] = None,
|
||||
alerting_threshold: float = 300, # threshold for slow / hanging llm responses (in seconds)
|
||||
alerting: Optional[List] = [],
|
||||
alert_types: Optional[
|
||||
List[
|
||||
|
@ -29,6 +86,7 @@ class SlackAlerting:
|
|||
"llm_requests_hanging",
|
||||
"budget_alerts",
|
||||
"db_exceptions",
|
||||
"daily_reports",
|
||||
]
|
||||
]
|
||||
] = [
|
||||
|
@ -37,31 +95,23 @@ class SlackAlerting:
|
|||
"llm_requests_hanging",
|
||||
"budget_alerts",
|
||||
"db_exceptions",
|
||||
"daily_reports",
|
||||
],
|
||||
alert_to_webhook_url: Optional[
|
||||
Dict
|
||||
] = None, # if user wants to separate alerts to diff channels
|
||||
alerting_args={},
|
||||
default_webhook_url: Optional[str] = None,
|
||||
):
|
||||
self.alerting_threshold = alerting_threshold
|
||||
self.alerting = alerting
|
||||
self.alert_types = alert_types
|
||||
self.internal_usage_cache = DualCache()
|
||||
self.internal_usage_cache = internal_usage_cache or DualCache()
|
||||
self.async_http_handler = AsyncHTTPHandler()
|
||||
self.alert_to_webhook_url = alert_to_webhook_url
|
||||
self.langfuse_logger = None
|
||||
|
||||
try:
|
||||
from litellm.integrations.langfuse import LangFuseLogger
|
||||
|
||||
self.langfuse_logger = LangFuseLogger(
|
||||
os.getenv("LANGFUSE_PUBLIC_KEY"),
|
||||
os.getenv("LANGFUSE_SECRET_KEY"),
|
||||
flush_interval=1,
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
pass
|
||||
self.is_running = False
|
||||
self.alerting_args = SlackAlertingArgs(**alerting_args)
|
||||
self.default_webhook_url = default_webhook_url
|
||||
|
||||
def update_values(
|
||||
self,
|
||||
|
@ -69,6 +119,7 @@ class SlackAlerting:
|
|||
alerting_threshold: Optional[float] = None,
|
||||
alert_types: Optional[List] = None,
|
||||
alert_to_webhook_url: Optional[Dict] = None,
|
||||
alerting_args: Optional[Dict] = None,
|
||||
):
|
||||
if alerting is not None:
|
||||
self.alerting = alerting
|
||||
|
@ -76,7 +127,8 @@ class SlackAlerting:
|
|||
self.alerting_threshold = alerting_threshold
|
||||
if alert_types is not None:
|
||||
self.alert_types = alert_types
|
||||
|
||||
if alerting_args is not None:
|
||||
self.alerting_args = SlackAlertingArgs(**alerting_args)
|
||||
if alert_to_webhook_url is not None:
|
||||
# update the dict
|
||||
if self.alert_to_webhook_url is None:
|
||||
|
@ -103,72 +155,23 @@ class SlackAlerting:
|
|||
|
||||
def _add_langfuse_trace_id_to_alert(
|
||||
self,
|
||||
request_info: str,
|
||||
request_data: Optional[dict] = None,
|
||||
kwargs: Optional[dict] = None,
|
||||
type: Literal["hanging_request", "slow_response"] = "hanging_request",
|
||||
start_time: Optional[datetime.datetime] = None,
|
||||
end_time: Optional[datetime.datetime] = None,
|
||||
):
|
||||
import uuid
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Returns langfuse trace url
|
||||
"""
|
||||
# do nothing for now
|
||||
if (
|
||||
request_data is not None
|
||||
and request_data.get("metadata", {}).get("trace_id", None) is not None
|
||||
):
|
||||
trace_id = request_data["metadata"]["trace_id"]
|
||||
if litellm.utils.langFuseLogger is not None:
|
||||
base_url = litellm.utils.langFuseLogger.Langfuse.base_url
|
||||
return f"{base_url}/trace/{trace_id}"
|
||||
return None
|
||||
|
||||
# For now: do nothing as we're debugging why this is not working as expected
|
||||
if request_data is not None:
|
||||
trace_id = request_data.get("metadata", {}).get(
|
||||
"trace_id", None
|
||||
) # get langfuse trace id
|
||||
if trace_id is None:
|
||||
trace_id = "litellm-alert-trace-" + str(uuid.uuid4())
|
||||
request_data["metadata"]["trace_id"] = trace_id
|
||||
elif kwargs is not None:
|
||||
_litellm_params = kwargs.get("litellm_params", {})
|
||||
trace_id = _litellm_params.get("metadata", {}).get(
|
||||
"trace_id", None
|
||||
) # get langfuse trace id
|
||||
if trace_id is None:
|
||||
trace_id = "litellm-alert-trace-" + str(uuid.uuid4())
|
||||
_litellm_params["metadata"]["trace_id"] = trace_id
|
||||
|
||||
# Log hanging request as an error on langfuse
|
||||
if type == "hanging_request":
|
||||
if self.langfuse_logger is not None:
|
||||
_logging_kwargs = copy.deepcopy(request_data)
|
||||
if _logging_kwargs is None:
|
||||
_logging_kwargs = {}
|
||||
_logging_kwargs["litellm_params"] = {}
|
||||
request_data = request_data or {}
|
||||
_logging_kwargs["litellm_params"]["metadata"] = request_data.get(
|
||||
"metadata", {}
|
||||
)
|
||||
# log to langfuse in a separate thread
|
||||
import threading
|
||||
|
||||
threading.Thread(
|
||||
target=self.langfuse_logger.log_event,
|
||||
args=(
|
||||
_logging_kwargs,
|
||||
None,
|
||||
start_time,
|
||||
end_time,
|
||||
None,
|
||||
print,
|
||||
"ERROR",
|
||||
"Requests is hanging",
|
||||
),
|
||||
).start()
|
||||
|
||||
_langfuse_host = os.environ.get("LANGFUSE_HOST", "https://cloud.langfuse.com")
|
||||
_langfuse_project_id = os.environ.get("LANGFUSE_PROJECT_ID")
|
||||
|
||||
# langfuse urls look like: https://us.cloud.langfuse.com/project/************/traces/litellm-alert-trace-ididi9dk-09292-************
|
||||
|
||||
_langfuse_url = (
|
||||
f"{_langfuse_host}/project/{_langfuse_project_id}/traces/{trace_id}"
|
||||
)
|
||||
request_info += f"\n🪢 Langfuse Trace: {_langfuse_url}"
|
||||
return request_info
|
||||
|
||||
def _response_taking_too_long_callback(
|
||||
def _response_taking_too_long_callback_helper(
|
||||
self,
|
||||
kwargs, # kwargs to completion
|
||||
start_time,
|
||||
|
@ -233,7 +236,7 @@ class SlackAlerting:
|
|||
return
|
||||
|
||||
time_difference_float, model, api_base, messages = (
|
||||
self._response_taking_too_long_callback(
|
||||
self._response_taking_too_long_callback_helper(
|
||||
kwargs=kwargs,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
|
@ -242,10 +245,6 @@ class SlackAlerting:
|
|||
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
|
||||
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
|
||||
if time_difference_float > self.alerting_threshold:
|
||||
if "langfuse" in litellm.success_callback:
|
||||
request_info = self._add_langfuse_trace_id_to_alert(
|
||||
request_info=request_info, kwargs=kwargs, type="slow_response"
|
||||
)
|
||||
# add deployment latencies to alert
|
||||
if (
|
||||
kwargs is not None
|
||||
|
@ -253,6 +252,9 @@ class SlackAlerting:
|
|||
and "metadata" in kwargs["litellm_params"]
|
||||
):
|
||||
_metadata = kwargs["litellm_params"]["metadata"]
|
||||
request_info = litellm.utils._add_key_name_and_team_to_alert(
|
||||
request_info=request_info, metadata=_metadata
|
||||
)
|
||||
|
||||
_deployment_latency_map = self._get_deployment_latencies_to_alert(
|
||||
metadata=_metadata
|
||||
|
@ -267,8 +269,178 @@ class SlackAlerting:
|
|||
alert_type="llm_too_slow",
|
||||
)
|
||||
|
||||
async def log_failure_event(self, original_exception: Exception):
|
||||
pass
|
||||
async def async_update_daily_reports(
|
||||
self, deployment_metrics: DeploymentMetrics
|
||||
) -> int:
|
||||
"""
|
||||
Store the perf by deployment in cache
|
||||
- Number of failed requests per deployment
|
||||
- Latency / output tokens per deployment
|
||||
|
||||
'deployment_id:daily_metrics:failed_requests'
|
||||
'deployment_id:daily_metrics:latency_per_output_token'
|
||||
|
||||
Returns
|
||||
int - count of metrics set (1 - if just latency, 2 - if failed + latency)
|
||||
"""
|
||||
|
||||
return_val = 0
|
||||
try:
|
||||
## FAILED REQUESTS ##
|
||||
if deployment_metrics.failed_request:
|
||||
await self.internal_usage_cache.async_increment_cache(
|
||||
key="{}:{}".format(
|
||||
deployment_metrics.id,
|
||||
SlackAlertingCacheKeys.failed_requests_key.value,
|
||||
),
|
||||
value=1,
|
||||
)
|
||||
|
||||
return_val += 1
|
||||
|
||||
## LATENCY ##
|
||||
if deployment_metrics.latency_per_output_token is not None:
|
||||
await self.internal_usage_cache.async_increment_cache(
|
||||
key="{}:{}".format(
|
||||
deployment_metrics.id, SlackAlertingCacheKeys.latency_key.value
|
||||
),
|
||||
value=deployment_metrics.latency_per_output_token,
|
||||
)
|
||||
|
||||
return_val += 1
|
||||
|
||||
return return_val
|
||||
except Exception as e:
|
||||
return 0
|
||||
|
||||
async def send_daily_reports(self, router) -> bool:
|
||||
"""
|
||||
Send a daily report on:
|
||||
- Top 5 deployments with most failed requests
|
||||
- Top 5 slowest deployments (normalized by latency/output tokens)
|
||||
|
||||
Get the value from redis cache (if available) or in-memory and send it
|
||||
|
||||
Cleanup:
|
||||
- reset values in cache -> prevent memory leak
|
||||
|
||||
Returns:
|
||||
True -> if successfuly sent
|
||||
False -> if not sent
|
||||
"""
|
||||
|
||||
ids = router.get_model_ids()
|
||||
|
||||
# get keys
|
||||
failed_request_keys = [
|
||||
"{}:{}".format(id, SlackAlertingCacheKeys.failed_requests_key.value)
|
||||
for id in ids
|
||||
]
|
||||
latency_keys = [
|
||||
"{}:{}".format(id, SlackAlertingCacheKeys.latency_key.value) for id in ids
|
||||
]
|
||||
|
||||
combined_metrics_keys = failed_request_keys + latency_keys # reduce cache calls
|
||||
|
||||
combined_metrics_values = await self.internal_usage_cache.async_batch_get_cache(
|
||||
keys=combined_metrics_keys
|
||||
) # [1, 2, None, ..]
|
||||
|
||||
all_none = True
|
||||
for val in combined_metrics_values:
|
||||
if val is not None:
|
||||
all_none = False
|
||||
|
||||
if all_none:
|
||||
return False
|
||||
|
||||
failed_request_values = combined_metrics_values[
|
||||
: len(failed_request_keys)
|
||||
] # # [1, 2, None, ..]
|
||||
latency_values = combined_metrics_values[len(failed_request_keys) :]
|
||||
|
||||
# find top 5 failed
|
||||
## Replace None values with a placeholder value (-1 in this case)
|
||||
placeholder_value = 0
|
||||
replaced_failed_values = [
|
||||
value if value is not None else placeholder_value
|
||||
for value in failed_request_values
|
||||
]
|
||||
|
||||
## Get the indices of top 5 keys with the highest numerical values (ignoring None values)
|
||||
top_5_failed = sorted(
|
||||
range(len(replaced_failed_values)),
|
||||
key=lambda i: replaced_failed_values[i],
|
||||
reverse=True,
|
||||
)[:5]
|
||||
|
||||
# find top 5 slowest
|
||||
# Replace None values with a placeholder value (-1 in this case)
|
||||
placeholder_value = 0
|
||||
replaced_slowest_values = [
|
||||
value if value is not None else placeholder_value
|
||||
for value in latency_values
|
||||
]
|
||||
|
||||
# Get the indices of top 5 values with the highest numerical values (ignoring None values)
|
||||
top_5_slowest = sorted(
|
||||
range(len(replaced_slowest_values)),
|
||||
key=lambda i: replaced_slowest_values[i],
|
||||
reverse=True,
|
||||
)[:5]
|
||||
|
||||
# format alert -> return the litellm model name + api base
|
||||
message = f"\n\nHere are today's key metrics 📈: \n\n"
|
||||
|
||||
message += "\n\n*❗️ Top 5 Deployments with Most Failed Requests:*\n\n"
|
||||
for i in range(len(top_5_failed)):
|
||||
key = failed_request_keys[top_5_failed[i]].split(":")[0]
|
||||
_deployment = router.get_model_info(key)
|
||||
if isinstance(_deployment, dict):
|
||||
deployment_name = _deployment["litellm_params"].get("model", "")
|
||||
else:
|
||||
return False
|
||||
|
||||
api_base = litellm.get_api_base(
|
||||
model=deployment_name,
|
||||
optional_params=(
|
||||
_deployment["litellm_params"] if _deployment is not None else {}
|
||||
),
|
||||
)
|
||||
if api_base is None:
|
||||
api_base = ""
|
||||
value = replaced_failed_values[top_5_failed[i]]
|
||||
message += f"\t{i+1}. Deployment: `{deployment_name}`, Failed Requests: `{value}`, API Base: `{api_base}`\n"
|
||||
|
||||
message += "\n\n*😅 Top 5 Slowest Deployments:*\n\n"
|
||||
for i in range(len(top_5_slowest)):
|
||||
key = latency_keys[top_5_slowest[i]].split(":")[0]
|
||||
_deployment = router.get_model_info(key)
|
||||
if _deployment is not None:
|
||||
deployment_name = _deployment["litellm_params"].get("model", "")
|
||||
else:
|
||||
deployment_name = ""
|
||||
api_base = litellm.get_api_base(
|
||||
model=deployment_name,
|
||||
optional_params=(
|
||||
_deployment["litellm_params"] if _deployment is not None else {}
|
||||
),
|
||||
)
|
||||
value = round(replaced_slowest_values[top_5_slowest[i]], 3)
|
||||
message += f"\t{i+1}. Deployment: `{deployment_name}`, Latency per output token: `{value}s/token`, API Base: `{api_base}`\n\n"
|
||||
|
||||
# cache cleanup -> reset values to 0
|
||||
latency_cache_keys = [(key, 0) for key in latency_keys]
|
||||
failed_request_cache_keys = [(key, 0) for key in failed_request_keys]
|
||||
combined_metrics_cache_keys = latency_cache_keys + failed_request_cache_keys
|
||||
await self.internal_usage_cache.async_batch_set_cache(
|
||||
cache_list=combined_metrics_cache_keys
|
||||
)
|
||||
|
||||
# send alert
|
||||
await self.send_alert(message=message, level="Low", alert_type="daily_reports")
|
||||
|
||||
return True
|
||||
|
||||
async def response_taking_too_long(
|
||||
self,
|
||||
|
@ -326,6 +498,11 @@ class SlackAlerting:
|
|||
# in that case we fallback to the api base set in the request metadata
|
||||
_metadata = request_data["metadata"]
|
||||
_api_base = _metadata.get("api_base", "")
|
||||
|
||||
request_info = litellm.utils._add_key_name_and_team_to_alert(
|
||||
request_info=request_info, metadata=_metadata
|
||||
)
|
||||
|
||||
if _api_base is None:
|
||||
_api_base = ""
|
||||
request_info += f"\nAPI Base: `{_api_base}`"
|
||||
|
@ -335,14 +512,13 @@ class SlackAlerting:
|
|||
)
|
||||
|
||||
if "langfuse" in litellm.success_callback:
|
||||
request_info = self._add_langfuse_trace_id_to_alert(
|
||||
request_info=request_info,
|
||||
langfuse_url = self._add_langfuse_trace_id_to_alert(
|
||||
request_data=request_data,
|
||||
type="hanging_request",
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
if langfuse_url is not None:
|
||||
request_info += "\n🪢 Langfuse Trace: {}".format(langfuse_url)
|
||||
|
||||
# add deployment latencies to alert
|
||||
_deployment_latency_map = self._get_deployment_latencies_to_alert(
|
||||
metadata=request_data.get("metadata", {})
|
||||
|
@ -475,6 +651,53 @@ class SlackAlerting:
|
|||
|
||||
return
|
||||
|
||||
async def model_added_alert(self, model_name: str, litellm_model_name: str):
|
||||
model_info = litellm.model_cost.get(litellm_model_name, {})
|
||||
model_info_str = ""
|
||||
for k, v in model_info.items():
|
||||
if k == "input_cost_per_token" or k == "output_cost_per_token":
|
||||
# when converting to string it should not be 1.63e-06
|
||||
v = "{:.8f}".format(v)
|
||||
|
||||
model_info_str += f"{k}: {v}\n"
|
||||
|
||||
message = f"""
|
||||
*🚅 New Model Added*
|
||||
Model Name: `{model_name}`
|
||||
|
||||
Usage OpenAI Python SDK:
|
||||
```
|
||||
import openai
|
||||
client = openai.OpenAI(
|
||||
api_key="your_api_key",
|
||||
base_url={os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000")}
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="{model_name}", # model to send to the proxy
|
||||
messages = [
|
||||
{{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}}
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
Model Info:
|
||||
```
|
||||
{model_info_str}
|
||||
```
|
||||
"""
|
||||
|
||||
await self.send_alert(
|
||||
message=message, level="Low", alert_type="new_model_added"
|
||||
)
|
||||
pass
|
||||
|
||||
async def model_removed_alert(self, model_name: str):
|
||||
pass
|
||||
|
||||
async def send_alert(
|
||||
self,
|
||||
message: str,
|
||||
|
@ -485,7 +708,11 @@ class SlackAlerting:
|
|||
"llm_requests_hanging",
|
||||
"budget_alerts",
|
||||
"db_exceptions",
|
||||
"daily_reports",
|
||||
"new_model_added",
|
||||
"cooldown_deployment",
|
||||
],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
|
||||
|
@ -510,9 +737,16 @@ class SlackAlerting:
|
|||
# Get the current timestamp
|
||||
current_time = datetime.now().strftime("%H:%M:%S")
|
||||
_proxy_base_url = os.getenv("PROXY_BASE_URL", None)
|
||||
formatted_message = (
|
||||
f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
|
||||
)
|
||||
if alert_type == "daily_reports" or alert_type == "new_model_added":
|
||||
formatted_message = message
|
||||
else:
|
||||
formatted_message = (
|
||||
f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
|
||||
)
|
||||
|
||||
if kwargs:
|
||||
for key, value in kwargs.items():
|
||||
formatted_message += f"\n\n{key}: `{value}`\n\n"
|
||||
if _proxy_base_url is not None:
|
||||
formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
|
||||
|
||||
|
@ -522,6 +756,8 @@ class SlackAlerting:
|
|||
and alert_type in self.alert_to_webhook_url
|
||||
):
|
||||
slack_webhook_url = self.alert_to_webhook_url[alert_type]
|
||||
elif self.default_webhook_url is not None:
|
||||
slack_webhook_url = self.default_webhook_url
|
||||
else:
|
||||
slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
|
||||
|
||||
|
@ -539,3 +775,113 @@ class SlackAlerting:
|
|||
pass
|
||||
else:
|
||||
print("Error sending slack alert. Error=", response.text) # noqa
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""Log deployment latency"""
|
||||
if "daily_reports" in self.alert_types:
|
||||
model_id = (
|
||||
kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
|
||||
)
|
||||
response_s: timedelta = end_time - start_time
|
||||
|
||||
final_value = response_s
|
||||
total_tokens = 0
|
||||
|
||||
if isinstance(response_obj, litellm.ModelResponse):
|
||||
completion_tokens = response_obj.usage.completion_tokens
|
||||
final_value = float(response_s.total_seconds() / completion_tokens)
|
||||
|
||||
await self.async_update_daily_reports(
|
||||
DeploymentMetrics(
|
||||
id=model_id,
|
||||
failed_request=False,
|
||||
latency_per_output_token=final_value,
|
||||
updated_at=litellm.utils.get_utc_datetime(),
|
||||
)
|
||||
)
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""Log failure + deployment latency"""
|
||||
if "daily_reports" in self.alert_types:
|
||||
model_id = (
|
||||
kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
|
||||
)
|
||||
await self.async_update_daily_reports(
|
||||
DeploymentMetrics(
|
||||
id=model_id,
|
||||
failed_request=True,
|
||||
latency_per_output_token=None,
|
||||
updated_at=litellm.utils.get_utc_datetime(),
|
||||
)
|
||||
)
|
||||
if "llm_exceptions" in self.alert_types:
|
||||
original_exception = kwargs.get("exception", None)
|
||||
|
||||
await self.send_alert(
|
||||
message="LLM API Failure - " + str(original_exception),
|
||||
level="High",
|
||||
alert_type="llm_exceptions",
|
||||
)
|
||||
|
||||
async def _run_scheduler_helper(self, llm_router) -> bool:
|
||||
"""
|
||||
Returns:
|
||||
- True -> report sent
|
||||
- False -> report not sent
|
||||
"""
|
||||
report_sent_bool = False
|
||||
|
||||
report_sent = await self.internal_usage_cache.async_get_cache(
|
||||
key=SlackAlertingCacheKeys.report_sent_key.value
|
||||
) # None | datetime
|
||||
|
||||
current_time = litellm.utils.get_utc_datetime()
|
||||
|
||||
if report_sent is None:
|
||||
_current_time = current_time.isoformat()
|
||||
await self.internal_usage_cache.async_set_cache(
|
||||
key=SlackAlertingCacheKeys.report_sent_key.value,
|
||||
value=_current_time,
|
||||
)
|
||||
else:
|
||||
# check if current time - interval >= time last sent
|
||||
delta = current_time - timedelta(
|
||||
seconds=self.alerting_args.daily_report_frequency
|
||||
)
|
||||
|
||||
if isinstance(report_sent, str):
|
||||
report_sent = dt.fromisoformat(report_sent)
|
||||
|
||||
if delta >= report_sent:
|
||||
# Sneak in the reporting logic here
|
||||
await self.send_daily_reports(router=llm_router)
|
||||
# Also, don't forget to update the report_sent time after sending the report!
|
||||
_current_time = current_time.isoformat()
|
||||
await self.internal_usage_cache.async_set_cache(
|
||||
key=SlackAlertingCacheKeys.report_sent_key.value,
|
||||
value=_current_time,
|
||||
)
|
||||
report_sent_bool = True
|
||||
|
||||
return report_sent_bool
|
||||
|
||||
async def _run_scheduled_daily_report(self, llm_router: Optional[Any] = None):
|
||||
"""
|
||||
If 'daily_reports' enabled
|
||||
|
||||
Ping redis cache every 5 minutes to check if we should send the report
|
||||
|
||||
If yes -> call send_daily_report()
|
||||
"""
|
||||
if llm_router is None or self.alert_types is None:
|
||||
return
|
||||
|
||||
if "daily_reports" in self.alert_types:
|
||||
while True:
|
||||
await self._run_scheduler_helper(llm_router=llm_router)
|
||||
interval = random.randint(
|
||||
self.alerting_args.report_check_interval - 3,
|
||||
self.alerting_args.report_check_interval + 3,
|
||||
) # shuffle to prevent collisions
|
||||
await asyncio.sleep(interval)
|
||||
return
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# On success + failure, log events to Supabase
|
||||
|
||||
import dotenv, os
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import os, types, traceback
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests
|
||||
import time, httpx
|
||||
import requests # type: ignore
|
||||
import time, httpx # type: ignore
|
||||
from typing import Callable, Optional
|
||||
from litellm.utils import ModelResponse, Choices, Message
|
||||
import litellm
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
import litellm
|
||||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||
import httpx
|
||||
import httpx # type: ignore
|
||||
|
||||
|
||||
class AlephAlphaError(Exception):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests, copy
|
||||
import requests, copy # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional, List
|
||||
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
|
||||
|
@ -9,7 +9,7 @@ import litellm
|
|||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from .base import BaseLLM
|
||||
import httpx
|
||||
import httpx # type: ignore
|
||||
|
||||
|
||||
class AnthropicConstants(Enum):
|
||||
|
@ -84,6 +84,51 @@ class AnthropicConfig:
|
|||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
return [
|
||||
"stream",
|
||||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"max_tokens",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
]
|
||||
|
||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "tools":
|
||||
optional_params["tools"] = value
|
||||
if param == "stream" and value == True:
|
||||
optional_params["stream"] = value
|
||||
if param == "stop":
|
||||
if isinstance(value, str):
|
||||
if (
|
||||
value == "\n"
|
||||
) and litellm.drop_params == True: # anthropic doesn't allow whitespace characters as stop-sequences
|
||||
continue
|
||||
value = [value]
|
||||
elif isinstance(value, list):
|
||||
new_v = []
|
||||
for v in value:
|
||||
if (
|
||||
v == "\n"
|
||||
) and litellm.drop_params == True: # anthropic doesn't allow whitespace characters as stop-sequences
|
||||
continue
|
||||
new_v.append(v)
|
||||
if len(new_v) > 0:
|
||||
value = new_v
|
||||
else:
|
||||
continue
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
return optional_params
|
||||
|
||||
|
||||
# makes headers for API call
|
||||
def validate_environment(api_key, user_headers):
|
||||
|
@ -139,11 +184,6 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
message=str(completion_response["error"]),
|
||||
status_code=response.status_code,
|
||||
)
|
||||
elif len(completion_response["content"]) == 0:
|
||||
raise AnthropicError(
|
||||
message="No content in response",
|
||||
status_code=response.status_code,
|
||||
)
|
||||
else:
|
||||
text_content = ""
|
||||
tool_calls = []
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional, Union, Any
|
||||
from typing import Optional, Union, Any, Literal
|
||||
import types, requests
|
||||
from .base import BaseLLM
|
||||
from litellm.utils import (
|
||||
|
@ -8,14 +8,16 @@ from litellm.utils import (
|
|||
CustomStreamWrapper,
|
||||
convert_to_model_response_object,
|
||||
TranscriptionResponse,
|
||||
get_secret,
|
||||
)
|
||||
from typing import Callable, Optional, BinaryIO
|
||||
from litellm import OpenAIConfig
|
||||
import litellm, json
|
||||
import httpx
|
||||
import httpx # type: ignore
|
||||
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
|
||||
from openai import AzureOpenAI, AsyncAzureOpenAI
|
||||
import uuid
|
||||
import os
|
||||
|
||||
|
||||
class AzureOpenAIError(Exception):
|
||||
|
@ -126,6 +128,51 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
|
|||
return azure_client_params
|
||||
|
||||
|
||||
def get_azure_ad_token_from_oidc(azure_ad_token: str):
|
||||
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
|
||||
azure_tenant = os.getenv("AZURE_TENANT_ID", None)
|
||||
|
||||
if azure_client_id is None or azure_tenant is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=422,
|
||||
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
|
||||
)
|
||||
|
||||
oidc_token = get_secret(azure_ad_token)
|
||||
|
||||
if oidc_token is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=401,
|
||||
message="OIDC token could not be retrieved from secret manager.",
|
||||
)
|
||||
|
||||
req_token = httpx.post(
|
||||
f"https://login.microsoftonline.com/{azure_tenant}/oauth2/v2.0/token",
|
||||
data={
|
||||
"client_id": azure_client_id,
|
||||
"grant_type": "client_credentials",
|
||||
"scope": "https://cognitiveservices.azure.com/.default",
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"client_assertion": oidc_token,
|
||||
},
|
||||
)
|
||||
|
||||
if req_token.status_code != 200:
|
||||
raise AzureOpenAIError(
|
||||
status_code=req_token.status_code,
|
||||
message=req_token.text,
|
||||
)
|
||||
|
||||
possible_azure_ad_token = req_token.json().get("access_token", None)
|
||||
|
||||
if possible_azure_ad_token is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="Azure AD Token not returned"
|
||||
)
|
||||
|
||||
return possible_azure_ad_token
|
||||
|
||||
|
||||
class AzureChatCompletion(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
@ -137,6 +184,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
headers["api-key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
headers["Authorization"] = f"Bearer {azure_ad_token}"
|
||||
return headers
|
||||
|
||||
|
@ -151,7 +200,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
api_type: str,
|
||||
azure_ad_token: str,
|
||||
print_verbose: Callable,
|
||||
timeout,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
logging_obj,
|
||||
optional_params,
|
||||
litellm_params,
|
||||
|
@ -189,6 +238,9 @@ class AzureChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
|
||||
if acompletion is True:
|
||||
|
@ -276,6 +328,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(**azure_client_params)
|
||||
|
@ -351,6 +405,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
|
||||
# setting Azure client
|
||||
|
@ -422,6 +478,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(**azure_client_params)
|
||||
|
@ -478,6 +536,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if client is None:
|
||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
||||
|
@ -599,6 +659,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
|
||||
## LOGGING
|
||||
|
@ -755,6 +817,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
|
||||
if aimg_generation == True:
|
||||
|
@ -833,6 +897,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
|
||||
if max_retries is not None:
|
||||
|
@ -952,6 +1018,81 @@ class AzureChatCompletion(BaseLLM):
|
|||
)
|
||||
raise e
|
||||
|
||||
def get_headers(
|
||||
self,
|
||||
model: Optional[str],
|
||||
api_key: str,
|
||||
api_base: str,
|
||||
api_version: str,
|
||||
timeout: float,
|
||||
mode: str,
|
||||
messages: Optional[list] = None,
|
||||
input: Optional[list] = None,
|
||||
prompt: Optional[str] = None,
|
||||
) -> dict:
|
||||
client_session = litellm.client_session or httpx.Client(
|
||||
transport=CustomHTTPTransport(), # handle dall-e-2 calls
|
||||
)
|
||||
if "gateway.ai.cloudflare.com" in api_base:
|
||||
## build base url - assume api base includes resource name
|
||||
if not api_base.endswith("/"):
|
||||
api_base += "/"
|
||||
api_base += f"{model}"
|
||||
client = AzureOpenAI(
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
http_client=client_session,
|
||||
)
|
||||
model = None
|
||||
# cloudflare ai gateway, needs model=None
|
||||
else:
|
||||
client = AzureOpenAI(
|
||||
api_version=api_version,
|
||||
azure_endpoint=api_base,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
http_client=client_session,
|
||||
)
|
||||
|
||||
# only run this check if it's not cloudflare ai gateway
|
||||
if model is None and mode != "image_generation":
|
||||
raise Exception("model is not set")
|
||||
|
||||
completion = None
|
||||
|
||||
if messages is None:
|
||||
messages = [{"role": "user", "content": "Hey"}]
|
||||
try:
|
||||
completion = client.chat.completions.with_raw_response.create(
|
||||
model=model, # type: ignore
|
||||
messages=messages, # type: ignore
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
response = {}
|
||||
|
||||
if completion is None or not hasattr(completion, "headers"):
|
||||
raise Exception("invalid completion response")
|
||||
|
||||
if (
|
||||
completion.headers.get("x-ratelimit-remaining-requests", None) is not None
|
||||
): # not provided for dall-e requests
|
||||
response["x-ratelimit-remaining-requests"] = completion.headers[
|
||||
"x-ratelimit-remaining-requests"
|
||||
]
|
||||
|
||||
if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
|
||||
response["x-ratelimit-remaining-tokens"] = completion.headers[
|
||||
"x-ratelimit-remaining-tokens"
|
||||
]
|
||||
|
||||
if completion.headers.get("x-ms-region", None) is not None:
|
||||
response["x-ms-region"] = completion.headers["x-ms-region"]
|
||||
|
||||
return response
|
||||
|
||||
async def ahealth_check(
|
||||
self,
|
||||
model: Optional[str],
|
||||
|
@ -963,7 +1104,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
messages: Optional[list] = None,
|
||||
input: Optional[list] = None,
|
||||
prompt: Optional[str] = None,
|
||||
):
|
||||
) -> dict:
|
||||
client_session = litellm.aclient_session or httpx.AsyncClient(
|
||||
transport=AsyncCustomHTTPTransport(), # handle dall-e-2 calls
|
||||
)
|
||||
|
@ -1040,4 +1181,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
response["x-ratelimit-remaining-tokens"] = completion.headers[
|
||||
"x-ratelimit-remaining-tokens"
|
||||
]
|
||||
|
||||
if completion.headers.get("x-ms-region", None) is not None:
|
||||
response["x-ms-region"] = completion.headers["x-ms-region"]
|
||||
|
||||
return response
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Optional, Union, Any
|
||||
import types, requests
|
||||
import types, requests # type: ignore
|
||||
from .base import BaseLLM
|
||||
from litellm.utils import (
|
||||
ModelResponse,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
import time
|
||||
from typing import Callable
|
||||
from litellm.utils import ModelResponse, Usage
|
||||
|
|
|
@ -4,7 +4,13 @@ from enum import Enum
|
|||
import time, uuid
|
||||
from typing import Callable, Optional, Any, Union, List
|
||||
import litellm
|
||||
from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse
|
||||
from litellm.utils import (
|
||||
ModelResponse,
|
||||
get_secret,
|
||||
Usage,
|
||||
ImageResponse,
|
||||
map_finish_reason,
|
||||
)
|
||||
from .prompt_templates.factory import (
|
||||
prompt_factory,
|
||||
custom_prompt,
|
||||
|
@ -157,6 +163,7 @@ class AmazonAnthropicClaude3Config:
|
|||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"extra_headers",
|
||||
]
|
||||
|
||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||
|
@ -524,6 +531,17 @@ class AmazonStabilityConfig:
|
|||
}
|
||||
|
||||
|
||||
def add_custom_header(headers):
|
||||
"""Closure to capture the headers and add them."""
|
||||
|
||||
def callback(request, **kwargs):
|
||||
"""Actual callback function that Boto3 will call."""
|
||||
for header_name, header_value in headers.items():
|
||||
request.headers.add_header(header_name, header_value)
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
def init_bedrock_client(
|
||||
region_name=None,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
|
@ -533,12 +551,13 @@ def init_bedrock_client(
|
|||
aws_session_name: Optional[str] = None,
|
||||
aws_profile_name: Optional[str] = None,
|
||||
aws_role_name: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
aws_web_identity_token: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
):
|
||||
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||
|
||||
## CHECK IS 'os.environ/' passed in
|
||||
# Define the list of parameters to check
|
||||
params_to_check = [
|
||||
|
@ -549,6 +568,7 @@ def init_bedrock_client(
|
|||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
aws_web_identity_token,
|
||||
]
|
||||
|
||||
# Iterate over parameters and update if needed
|
||||
|
@ -564,6 +584,7 @@ def init_bedrock_client(
|
|||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
aws_web_identity_token,
|
||||
) = params_to_check
|
||||
|
||||
### SET REGION NAME
|
||||
|
@ -592,10 +613,48 @@ def init_bedrock_client(
|
|||
|
||||
import boto3
|
||||
|
||||
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
|
||||
if isinstance(timeout, float):
|
||||
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
|
||||
elif isinstance(timeout, httpx.Timeout):
|
||||
config = boto3.session.Config(
|
||||
connect_timeout=timeout.connect, read_timeout=timeout.read
|
||||
)
|
||||
else:
|
||||
config = boto3.session.Config()
|
||||
|
||||
### CHECK STS ###
|
||||
if aws_role_name is not None and aws_session_name is not None:
|
||||
if aws_web_identity_token is not None and aws_role_name is not None and aws_session_name is not None:
|
||||
oidc_token = get_secret(aws_web_identity_token)
|
||||
|
||||
if oidc_token is None:
|
||||
raise BedrockError(
|
||||
message="OIDC token could not be retrieved from secret manager.",
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
sts_client = boto3.client(
|
||||
"sts"
|
||||
)
|
||||
|
||||
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||
sts_response = sts_client.assume_role_with_web_identity(
|
||||
RoleArn=aws_role_name,
|
||||
RoleSessionName=aws_session_name,
|
||||
WebIdentityToken=oidc_token,
|
||||
DurationSeconds=3600,
|
||||
)
|
||||
|
||||
client = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
||||
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
||||
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
||||
region_name=region_name,
|
||||
endpoint_url=endpoint_url,
|
||||
config=config,
|
||||
)
|
||||
elif aws_role_name is not None and aws_session_name is not None:
|
||||
# use sts if role name passed in
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
|
@ -647,6 +706,10 @@ def init_bedrock_client(
|
|||
endpoint_url=endpoint_url,
|
||||
config=config,
|
||||
)
|
||||
if extra_headers:
|
||||
client.meta.events.register(
|
||||
"before-sign.bedrock-runtime.*", add_custom_header(extra_headers)
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
|
@ -710,6 +773,7 @@ def completion(
|
|||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
timeout=None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
):
|
||||
exception_mapping_worked = False
|
||||
_is_function_call = False
|
||||
|
@ -725,6 +789,7 @@ def completion(
|
|||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
)
|
||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
|
||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||
client = optional_params.pop("aws_bedrock_client", None)
|
||||
|
@ -739,6 +804,8 @@ def completion(
|
|||
aws_role_name=aws_role_name,
|
||||
aws_session_name=aws_session_name,
|
||||
aws_profile_name=aws_profile_name,
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
@ -1043,7 +1110,9 @@ def completion(
|
|||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
model_response["finish_reason"] = response_body["stop_reason"]
|
||||
model_response["finish_reason"] = map_finish_reason(
|
||||
response_body["stop_reason"]
|
||||
)
|
||||
_usage = litellm.Usage(
|
||||
prompt_tokens=response_body["usage"]["input_tokens"],
|
||||
completion_tokens=response_body["usage"]["output_tokens"],
|
||||
|
@ -1194,7 +1263,7 @@ def _embedding_func_single(
|
|||
"input_type", "search_document"
|
||||
) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
|
||||
data = {"texts": [input], **inference_params} # type: ignore
|
||||
body = json.dumps(data).encode("utf-8")
|
||||
body = json.dumps(data).encode("utf-8") # type: ignore
|
||||
## LOGGING
|
||||
request_str = f"""
|
||||
response = client.invoke_model(
|
||||
|
@ -1258,6 +1327,7 @@ def embedding(
|
|||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
)
|
||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
|
||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||
client = init_bedrock_client(
|
||||
|
@ -1265,6 +1335,7 @@ def embedding(
|
|||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_session_name=aws_session_name,
|
||||
)
|
||||
|
@ -1347,6 +1418,7 @@ def image_generation(
|
|||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
)
|
||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
|
||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||
client = init_bedrock_client(
|
||||
|
@ -1354,6 +1426,7 @@ def image_generation(
|
|||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_session_name=aws_session_name,
|
||||
timeout=timeout,
|
||||
|
@ -1386,7 +1459,7 @@ def image_generation(
|
|||
## LOGGING
|
||||
request_str = f"""
|
||||
response = client.invoke_model(
|
||||
body={body},
|
||||
body={body}, # type: ignore
|
||||
modelId={modelId},
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
import litellm
|
||||
import httpx
|
||||
import httpx # type: ignore
|
||||
from litellm.utils import ModelResponse, Usage
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
import time, traceback
|
||||
from typing import Callable, Optional
|
||||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||
import litellm
|
||||
import httpx
|
||||
import httpx # type: ignore
|
||||
|
||||
|
||||
class CohereError(Exception):
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
import time, traceback
|
||||
from typing import Callable, Optional
|
||||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||
import litellm
|
||||
import httpx
|
||||
import httpx # type: ignore
|
||||
from .prompt_templates.factory import cohere_message_pt
|
||||
|
||||
|
||||
|
|
|
@ -6,10 +6,12 @@ import httpx, requests
|
|||
from .base import BaseLLM
|
||||
import time
|
||||
import litellm
|
||||
from typing import Callable, Dict, List, Any
|
||||
from typing import Callable, Dict, List, Any, Literal
|
||||
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
|
||||
from typing import Optional
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
from litellm.types.completion import ChatCompletionMessageToolCallParam
|
||||
import enum
|
||||
|
||||
|
||||
class HuggingfaceError(Exception):
|
||||
|
@ -39,11 +41,29 @@ class HuggingfaceError(Exception):
|
|||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
hf_task_list = [
|
||||
"text-generation-inference",
|
||||
"conversational",
|
||||
"text-classification",
|
||||
"text-generation",
|
||||
]
|
||||
|
||||
hf_tasks = Literal[
|
||||
"text-generation-inference",
|
||||
"conversational",
|
||||
"text-classification",
|
||||
"text-generation",
|
||||
]
|
||||
|
||||
|
||||
class HuggingfaceConfig:
|
||||
"""
|
||||
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
|
||||
"""
|
||||
|
||||
hf_task: Optional[hf_tasks] = (
|
||||
None # litellm-specific param, used to know the api spec to use when calling huggingface api
|
||||
)
|
||||
best_of: Optional[int] = None
|
||||
decoder_input_details: Optional[bool] = None
|
||||
details: Optional[bool] = True # enables returning logprobs + best of
|
||||
|
@ -101,6 +121,51 @@ class HuggingfaceConfig:
|
|||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"stop",
|
||||
"n",
|
||||
"echo",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self, non_default_params: dict, optional_params: dict
|
||||
) -> dict:
|
||||
for param, value in non_default_params.items():
|
||||
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
|
||||
if param == "temperature":
|
||||
if value == 0.0 or value == 0:
|
||||
# hugging face exception raised when temp==0
|
||||
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
|
||||
value = 0.01
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
if param == "n":
|
||||
optional_params["best_of"] = value
|
||||
optional_params["do_sample"] = (
|
||||
True # Need to sample if you want best of for hf inference endpoints
|
||||
)
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop"] = value
|
||||
if param == "max_tokens":
|
||||
# HF TGI raises the following exception when max_new_tokens==0
|
||||
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
|
||||
if value == 0:
|
||||
value = 1
|
||||
optional_params["max_new_tokens"] = value
|
||||
if param == "echo":
|
||||
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
|
||||
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
|
||||
optional_params["decoder_input_details"] = True
|
||||
return optional_params
|
||||
|
||||
|
||||
def output_parser(generated_text: str):
|
||||
"""
|
||||
|
@ -162,16 +227,18 @@ def read_tgi_conv_models():
|
|||
return set(), set()
|
||||
|
||||
|
||||
def get_hf_task_for_model(model):
|
||||
def get_hf_task_for_model(model: str) -> hf_tasks:
|
||||
# read text file, cast it to set
|
||||
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
|
||||
if model.split("/")[0] in hf_task_list:
|
||||
return model.split("/")[0] # type: ignore
|
||||
tgi_models, conversational_models = read_tgi_conv_models()
|
||||
if model in tgi_models:
|
||||
return "text-generation-inference"
|
||||
elif model in conversational_models:
|
||||
return "conversational"
|
||||
elif "roneneldan/TinyStories" in model:
|
||||
return None
|
||||
return "text-generation"
|
||||
else:
|
||||
return "text-generation-inference" # default to tgi
|
||||
|
||||
|
@ -202,7 +269,7 @@ class Huggingface(BaseLLM):
|
|||
self,
|
||||
completion_response,
|
||||
model_response,
|
||||
task,
|
||||
task: hf_tasks,
|
||||
optional_params,
|
||||
encoding,
|
||||
input_text,
|
||||
|
@ -270,6 +337,10 @@ class Huggingface(BaseLLM):
|
|||
)
|
||||
choices_list.append(choice_obj)
|
||||
model_response["choices"].extend(choices_list)
|
||||
elif task == "text-classification":
|
||||
model_response["choices"][0]["message"]["content"] = json.dumps(
|
||||
completion_response
|
||||
)
|
||||
else:
|
||||
if len(completion_response[0]["generated_text"]) > 0:
|
||||
model_response["choices"][0]["message"]["content"] = output_parser(
|
||||
|
@ -322,9 +393,9 @@ class Huggingface(BaseLLM):
|
|||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
custom_prompt_dict={},
|
||||
acompletion: bool = False,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
|
@ -333,6 +404,12 @@ class Huggingface(BaseLLM):
|
|||
try:
|
||||
headers = self.validate_environment(api_key, headers)
|
||||
task = get_hf_task_for_model(model)
|
||||
## VALIDATE API FORMAT
|
||||
if task is None or not isinstance(task, str) or task not in hf_task_list:
|
||||
raise Exception(
|
||||
"Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
|
||||
)
|
||||
|
||||
print_verbose(f"{model}, {task}")
|
||||
completion_url = ""
|
||||
input_text = ""
|
||||
|
@ -399,10 +476,11 @@ class Huggingface(BaseLLM):
|
|||
data = {
|
||||
"inputs": prompt,
|
||||
"parameters": optional_params,
|
||||
"stream": (
|
||||
"stream": ( # type: ignore
|
||||
True
|
||||
if "stream" in optional_params
|
||||
and optional_params["stream"] == True
|
||||
and isinstance(optional_params["stream"], bool)
|
||||
and optional_params["stream"] == True # type: ignore
|
||||
else False
|
||||
),
|
||||
}
|
||||
|
@ -432,14 +510,15 @@ class Huggingface(BaseLLM):
|
|||
inference_params.pop("return_full_text")
|
||||
data = {
|
||||
"inputs": prompt,
|
||||
"parameters": inference_params,
|
||||
"stream": (
|
||||
}
|
||||
if task == "text-generation-inference":
|
||||
data["parameters"] = inference_params
|
||||
data["stream"] = ( # type: ignore
|
||||
True
|
||||
if "stream" in optional_params
|
||||
and optional_params["stream"] == True
|
||||
else False
|
||||
),
|
||||
}
|
||||
)
|
||||
input_text = prompt
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
|
@ -530,10 +609,10 @@ class Huggingface(BaseLLM):
|
|||
isinstance(completion_response, dict)
|
||||
and "error" in completion_response
|
||||
):
|
||||
print_verbose(f"completion error: {completion_response['error']}")
|
||||
print_verbose(f"completion error: {completion_response['error']}") # type: ignore
|
||||
print_verbose(f"response.status_code: {response.status_code}")
|
||||
raise HuggingfaceError(
|
||||
message=completion_response["error"],
|
||||
message=completion_response["error"], # type: ignore
|
||||
status_code=response.status_code,
|
||||
)
|
||||
return self.convert_to_model_response_object(
|
||||
|
@ -562,7 +641,7 @@ class Huggingface(BaseLLM):
|
|||
data: dict,
|
||||
headers: dict,
|
||||
model_response: ModelResponse,
|
||||
task: str,
|
||||
task: hf_tasks,
|
||||
encoding: Any,
|
||||
input_text: str,
|
||||
model: str,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
import time, traceback
|
||||
from typing import Callable, Optional, List
|
||||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
import litellm
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import requests, types, time
|
||||
from itertools import chain
|
||||
import requests, types, time # type: ignore
|
||||
import json, uuid
|
||||
import traceback
|
||||
from typing import Optional
|
||||
import litellm
|
||||
import httpx, aiohttp, asyncio
|
||||
import httpx, aiohttp, asyncio # type: ignore
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
|
||||
|
@ -212,25 +213,31 @@ def get_ollama_response(
|
|||
|
||||
## RESPONSE OBJECT
|
||||
model_response["choices"][0]["finish_reason"] = "stop"
|
||||
if optional_params.get("format", "") == "json":
|
||||
if data.get("format", "") == "json":
|
||||
function_call = json.loads(response_json["response"])
|
||||
message = litellm.Message(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
{
|
||||
"id": f"call_{str(uuid.uuid4())}",
|
||||
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
|
||||
"function": {
|
||||
"name": function_call["name"],
|
||||
"arguments": json.dumps(function_call["arguments"]),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
)
|
||||
model_response["choices"][0]["message"] = message
|
||||
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||
else:
|
||||
model_response["choices"][0]["message"]["content"] = response_json["response"]
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = "ollama/" + model
|
||||
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=()))) # type: ignore
|
||||
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
|
||||
completion_tokens = response_json.get(
|
||||
"eval_count", len(response_json.get("message", dict()).get("content", ""))
|
||||
)
|
||||
model_response["usage"] = litellm.Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
|
@ -255,8 +262,37 @@ def ollama_completion_stream(url, data, logging_obj):
|
|||
custom_llm_provider="ollama",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
# If format is JSON, this was a function call
|
||||
# Gather all chunks and return the function call as one delta to simplify parsing
|
||||
if data.get("format", "") == "json":
|
||||
first_chunk = next(streamwrapper)
|
||||
response_content = "".join(
|
||||
chunk.choices[0].delta.content
|
||||
for chunk in chain([first_chunk], streamwrapper)
|
||||
if chunk.choices[0].delta.content
|
||||
)
|
||||
|
||||
function_call = json.loads(response_content)
|
||||
delta = litellm.utils.Delta(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
{
|
||||
"id": f"call_{str(uuid.uuid4())}",
|
||||
"function": {
|
||||
"name": function_call["name"],
|
||||
"arguments": json.dumps(function_call["arguments"]),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
)
|
||||
model_response = first_chunk
|
||||
model_response["choices"][0]["delta"] = delta
|
||||
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||
yield model_response
|
||||
else:
|
||||
for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
@ -278,8 +314,40 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
|
|||
custom_llm_provider="ollama",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
|
||||
# If format is JSON, this was a function call
|
||||
# Gather all chunks and return the function call as one delta to simplify parsing
|
||||
if data.get("format", "") == "json":
|
||||
first_chunk = await anext(streamwrapper)
|
||||
first_chunk_content = first_chunk.choices[0].delta.content or ""
|
||||
response_content = first_chunk_content + "".join(
|
||||
[
|
||||
chunk.choices[0].delta.content
|
||||
async for chunk in streamwrapper
|
||||
if chunk.choices[0].delta.content
|
||||
]
|
||||
)
|
||||
function_call = json.loads(response_content)
|
||||
delta = litellm.utils.Delta(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
{
|
||||
"id": f"call_{str(uuid.uuid4())}",
|
||||
"function": {
|
||||
"name": function_call["name"],
|
||||
"arguments": json.dumps(function_call["arguments"]),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
)
|
||||
model_response = first_chunk
|
||||
model_response["choices"][0]["delta"] = delta
|
||||
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||
yield model_response
|
||||
else:
|
||||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
@ -317,12 +385,16 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
|
|||
tool_calls=[
|
||||
{
|
||||
"id": f"call_{str(uuid.uuid4())}",
|
||||
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
|
||||
"function": {
|
||||
"name": function_call["name"],
|
||||
"arguments": json.dumps(function_call["arguments"]),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
)
|
||||
model_response["choices"][0]["message"] = message
|
||||
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||
else:
|
||||
model_response["choices"][0]["message"]["content"] = response_json[
|
||||
"response"
|
||||
|
@ -330,7 +402,10 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
|
|||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = "ollama/" + data["model"]
|
||||
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=()))) # type: ignore
|
||||
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
|
||||
completion_tokens = response_json.get(
|
||||
"eval_count",
|
||||
len(response_json.get("message", dict()).get("content", "")),
|
||||
)
|
||||
model_response["usage"] = litellm.Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
|
@ -417,3 +492,25 @@ async def ollama_aembeddings(
|
|||
"total_tokens": total_input_tokens,
|
||||
}
|
||||
return model_response
|
||||
|
||||
|
||||
def ollama_embeddings(
|
||||
api_base: str,
|
||||
model: str,
|
||||
prompts: list,
|
||||
optional_params=None,
|
||||
logging_obj=None,
|
||||
model_response=None,
|
||||
encoding=None,
|
||||
):
|
||||
return asyncio.run(
|
||||
ollama_aembeddings(
|
||||
api_base,
|
||||
model,
|
||||
prompts,
|
||||
optional_params,
|
||||
logging_obj,
|
||||
model_response,
|
||||
encoding,
|
||||
)
|
||||
)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from itertools import chain
|
||||
import requests, types, time
|
||||
import json, uuid
|
||||
import traceback
|
||||
|
@ -297,6 +298,7 @@ def get_ollama_response(
|
|||
],
|
||||
)
|
||||
model_response["choices"][0]["message"] = message
|
||||
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||
else:
|
||||
model_response["choices"][0]["message"] = response_json["message"]
|
||||
model_response["created"] = int(time.time())
|
||||
|
@ -335,8 +337,35 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
|
|||
custom_llm_provider="ollama_chat",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
|
||||
# If format is JSON, this was a function call
|
||||
# Gather all chunks and return the function call as one delta to simplify parsing
|
||||
if data.get("format", "") == "json":
|
||||
first_chunk = next(streamwrapper)
|
||||
response_content = "".join(
|
||||
chunk.choices[0].delta.content
|
||||
for chunk in chain([first_chunk], streamwrapper)
|
||||
if chunk.choices[0].delta.content
|
||||
)
|
||||
|
||||
function_call = json.loads(response_content)
|
||||
delta = litellm.utils.Delta(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
{
|
||||
"id": f"call_{str(uuid.uuid4())}",
|
||||
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
)
|
||||
model_response = first_chunk
|
||||
model_response["choices"][0]["delta"] = delta
|
||||
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||
yield model_response
|
||||
else:
|
||||
for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
@ -366,8 +395,36 @@ async def ollama_async_streaming(
|
|||
custom_llm_provider="ollama_chat",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
|
||||
# If format is JSON, this was a function call
|
||||
# Gather all chunks and return the function call as one delta to simplify parsing
|
||||
if data.get("format", "") == "json":
|
||||
first_chunk = await anext(streamwrapper)
|
||||
first_chunk_content = first_chunk.choices[0].delta.content or ""
|
||||
response_content = first_chunk_content + "".join(
|
||||
[
|
||||
chunk.choices[0].delta.content
|
||||
async for chunk in streamwrapper
|
||||
if chunk.choices[0].delta.content]
|
||||
)
|
||||
function_call = json.loads(response_content)
|
||||
delta = litellm.utils.Delta(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
{
|
||||
"id": f"call_{str(uuid.uuid4())}",
|
||||
"function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
)
|
||||
model_response = first_chunk
|
||||
model_response["choices"][0]["delta"] = delta
|
||||
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||
yield model_response
|
||||
else:
|
||||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
|
@ -425,6 +482,7 @@ async def ollama_acompletion(
|
|||
],
|
||||
)
|
||||
model_response["choices"][0]["message"] = message
|
||||
model_response["choices"][0]["finish_reason"] = "tool_calls"
|
||||
else:
|
||||
model_response["choices"][0]["message"] = response_json["message"]
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
from litellm.utils import ModelResponse, Usage
|
||||
|
|
|
@ -1,4 +1,13 @@
|
|||
from typing import Optional, Union, Any, BinaryIO
|
||||
from typing import (
|
||||
Optional,
|
||||
Union,
|
||||
Any,
|
||||
BinaryIO,
|
||||
Literal,
|
||||
Iterable,
|
||||
)
|
||||
from typing_extensions import override
|
||||
from pydantic import BaseModel
|
||||
import types, time, json, traceback
|
||||
import httpx
|
||||
from .base import BaseLLM
|
||||
|
@ -13,10 +22,10 @@ from litellm.utils import (
|
|||
TextCompletionResponse,
|
||||
)
|
||||
from typing import Callable, Optional
|
||||
import aiohttp, requests
|
||||
import litellm
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
from openai import OpenAI, AsyncOpenAI
|
||||
from ..types.llms.openai import *
|
||||
|
||||
|
||||
class OpenAIError(Exception):
|
||||
|
@ -246,7 +255,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
def completion(
|
||||
self,
|
||||
model_response: ModelResponse,
|
||||
timeout: float,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
model: Optional[str] = None,
|
||||
messages: Optional[list] = None,
|
||||
print_verbose: Optional[Callable] = None,
|
||||
|
@ -271,9 +280,12 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
if model is None or messages is None:
|
||||
raise OpenAIError(status_code=422, message=f"Missing model or messages")
|
||||
|
||||
if not isinstance(timeout, float):
|
||||
if not isinstance(timeout, float) and not isinstance(
|
||||
timeout, httpx.Timeout
|
||||
):
|
||||
raise OpenAIError(
|
||||
status_code=422, message=f"Timeout needs to be a float"
|
||||
status_code=422,
|
||||
message=f"Timeout needs to be a float or httpx.Timeout",
|
||||
)
|
||||
|
||||
if custom_llm_provider != "openai":
|
||||
|
@ -425,7 +437,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
self,
|
||||
data: dict,
|
||||
model_response: ModelResponse,
|
||||
timeout: float,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
|
@ -480,7 +492,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
def streaming(
|
||||
self,
|
||||
logging_obj,
|
||||
timeout: float,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
data: dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
|
@ -518,13 +530,14 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
model=model,
|
||||
custom_llm_provider="openai",
|
||||
logging_obj=logging_obj,
|
||||
stream_options=data.get("stream_options", None),
|
||||
)
|
||||
return streamwrapper
|
||||
|
||||
async def async_streaming(
|
||||
self,
|
||||
logging_obj,
|
||||
timeout: float,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
data: dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
|
@ -567,6 +580,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
model=model,
|
||||
custom_llm_provider="openai",
|
||||
logging_obj=logging_obj,
|
||||
stream_options=data.get("stream_options", None),
|
||||
)
|
||||
return streamwrapper
|
||||
except (
|
||||
|
@ -1191,6 +1205,7 @@ class OpenAITextCompletion(BaseLLM):
|
|||
model=model,
|
||||
custom_llm_provider="text-completion-openai",
|
||||
logging_obj=logging_obj,
|
||||
stream_options=data.get("stream_options", None),
|
||||
)
|
||||
|
||||
for chunk in streamwrapper:
|
||||
|
@ -1229,7 +1244,228 @@ class OpenAITextCompletion(BaseLLM):
|
|||
model=model,
|
||||
custom_llm_provider="text-completion-openai",
|
||||
logging_obj=logging_obj,
|
||||
stream_options=data.get("stream_options", None),
|
||||
)
|
||||
|
||||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
|
||||
|
||||
class OpenAIAssistantsAPI(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def get_openai_client(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[OpenAI] = None,
|
||||
) -> OpenAI:
|
||||
received_args = locals()
|
||||
if client is None:
|
||||
data = {}
|
||||
for k, v in received_args.items():
|
||||
if k == "self" or k == "client":
|
||||
pass
|
||||
elif k == "api_base" and v is not None:
|
||||
data["base_url"] = v
|
||||
elif v is not None:
|
||||
data[k] = v
|
||||
openai_client = OpenAI(**data) # type: ignore
|
||||
else:
|
||||
openai_client = client
|
||||
|
||||
return openai_client
|
||||
|
||||
### ASSISTANTS ###
|
||||
|
||||
def get_assistants(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[OpenAI],
|
||||
) -> SyncCursorPage[Assistant]:
|
||||
openai_client = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
)
|
||||
|
||||
response = openai_client.beta.assistants.list()
|
||||
|
||||
return response
|
||||
|
||||
### MESSAGES ###
|
||||
|
||||
def add_message(
|
||||
self,
|
||||
thread_id: str,
|
||||
message_data: MessageData,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[OpenAI] = None,
|
||||
) -> OpenAIMessage:
|
||||
|
||||
openai_client = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
)
|
||||
|
||||
thread_message: OpenAIMessage = openai_client.beta.threads.messages.create(
|
||||
thread_id, **message_data
|
||||
)
|
||||
|
||||
response_obj: Optional[OpenAIMessage] = None
|
||||
if getattr(thread_message, "status", None) is None:
|
||||
thread_message.status = "completed"
|
||||
response_obj = OpenAIMessage(**thread_message.dict())
|
||||
else:
|
||||
response_obj = OpenAIMessage(**thread_message.dict())
|
||||
return response_obj
|
||||
|
||||
def get_messages(
|
||||
self,
|
||||
thread_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[OpenAI] = None,
|
||||
) -> SyncCursorPage[OpenAIMessage]:
|
||||
openai_client = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
)
|
||||
|
||||
response = openai_client.beta.threads.messages.list(thread_id=thread_id)
|
||||
|
||||
return response
|
||||
|
||||
### THREADS ###
|
||||
|
||||
def create_thread(
|
||||
self,
|
||||
metadata: Optional[dict],
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[OpenAI],
|
||||
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
||||
) -> Thread:
|
||||
"""
|
||||
Here's an example:
|
||||
```
|
||||
from litellm.llms.openai import OpenAIAssistantsAPI, MessageData
|
||||
|
||||
# create thread
|
||||
message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
|
||||
openai_api.create_thread(messages=[message])
|
||||
```
|
||||
"""
|
||||
openai_client = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
)
|
||||
|
||||
data = {}
|
||||
if messages is not None:
|
||||
data["messages"] = messages # type: ignore
|
||||
if metadata is not None:
|
||||
data["metadata"] = metadata # type: ignore
|
||||
|
||||
message_thread = openai_client.beta.threads.create(**data) # type: ignore
|
||||
|
||||
return Thread(**message_thread.dict())
|
||||
|
||||
def get_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[OpenAI],
|
||||
) -> Thread:
|
||||
openai_client = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
)
|
||||
|
||||
response = openai_client.beta.threads.retrieve(thread_id=thread_id)
|
||||
|
||||
return Thread(**response.dict())
|
||||
|
||||
def delete_thread(self):
|
||||
pass
|
||||
|
||||
### RUNS ###
|
||||
|
||||
def run_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
assistant_id: str,
|
||||
additional_instructions: Optional[str],
|
||||
instructions: Optional[str],
|
||||
metadata: Optional[object],
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
tools: Optional[Iterable[AssistantToolParam]],
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[OpenAI],
|
||||
) -> Run:
|
||||
openai_client = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
)
|
||||
|
||||
response = openai_client.beta.threads.runs.create_and_poll(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
additional_instructions=additional_instructions,
|
||||
instructions=instructions,
|
||||
metadata=metadata,
|
||||
model=model,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
import litellm
|
||||
|
|
520
litellm/llms/predibase.py
Normal file
520
litellm/llms/predibase.py
Normal file
|
@ -0,0 +1,520 @@
|
|||
# What is this?
|
||||
## Controller file for Predibase Integration - https://predibase.com/
|
||||
|
||||
|
||||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests, copy # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional, List, Literal, Union
|
||||
from litellm.utils import (
|
||||
ModelResponse,
|
||||
Usage,
|
||||
map_finish_reason,
|
||||
CustomStreamWrapper,
|
||||
Message,
|
||||
Choices,
|
||||
)
|
||||
import litellm
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from .base import BaseLLM
|
||||
import httpx # type: ignore
|
||||
|
||||
|
||||
class PredibaseError(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
status_code,
|
||||
message,
|
||||
request: Optional[httpx.Request] = None,
|
||||
response: Optional[httpx.Response] = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
if request is not None:
|
||||
self.request = request
|
||||
else:
|
||||
self.request = httpx.Request(
|
||||
method="POST",
|
||||
url="https://docs.predibase.com/user-guide/inference/rest_api",
|
||||
)
|
||||
if response is not None:
|
||||
self.response = response
|
||||
else:
|
||||
self.response = httpx.Response(
|
||||
status_code=status_code, request=self.request
|
||||
)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class PredibaseConfig:
|
||||
"""
|
||||
Reference: https://docs.predibase.com/user-guide/inference/rest_api
|
||||
|
||||
"""
|
||||
|
||||
adapter_id: Optional[str] = None
|
||||
adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
|
||||
best_of: Optional[int] = None
|
||||
decoder_input_details: Optional[bool] = None
|
||||
details: bool = True # enables returning logprobs + best of
|
||||
max_new_tokens: int = (
|
||||
256 # openai default - requests hang if max_new_tokens not given
|
||||
)
|
||||
repetition_penalty: Optional[float] = None
|
||||
return_full_text: Optional[bool] = (
|
||||
False # by default don't return the input as part of the output
|
||||
)
|
||||
seed: Optional[int] = None
|
||||
stop: Optional[List[str]] = None
|
||||
temperature: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
truncate: Optional[int] = None
|
||||
typical_p: Optional[float] = None
|
||||
watermark: Optional[bool] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
best_of: Optional[int] = None,
|
||||
decoder_input_details: Optional[bool] = None,
|
||||
details: Optional[bool] = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
return_full_text: Optional[bool] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
truncate: Optional[int] = None,
|
||||
typical_p: Optional[float] = None,
|
||||
watermark: Optional[bool] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
|
||||
|
||||
|
||||
class PredibaseChatCompletion(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict:
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
|
||||
)
|
||||
headers = {
|
||||
"content-type": "application/json",
|
||||
"Authorization": "Bearer {}".format(api_key),
|
||||
}
|
||||
if user_headers is not None and isinstance(user_headers, dict):
|
||||
headers = {**headers, **user_headers}
|
||||
return headers
|
||||
|
||||
def output_parser(self, generated_text: str):
|
||||
"""
|
||||
Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
|
||||
|
||||
Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
|
||||
"""
|
||||
chat_template_tokens = [
|
||||
"<|assistant|>",
|
||||
"<|system|>",
|
||||
"<|user|>",
|
||||
"<s>",
|
||||
"</s>",
|
||||
]
|
||||
for token in chat_template_tokens:
|
||||
if generated_text.strip().startswith(token):
|
||||
generated_text = generated_text.replace(token, "", 1)
|
||||
if generated_text.endswith(token):
|
||||
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
|
||||
return generated_text
|
||||
|
||||
def process_response(
|
||||
self,
|
||||
model: str,
|
||||
response: Union[requests.Response, httpx.Response],
|
||||
model_response: ModelResponse,
|
||||
stream: bool,
|
||||
logging_obj: litellm.utils.Logging,
|
||||
optional_params: dict,
|
||||
api_key: str,
|
||||
data: dict,
|
||||
messages: list,
|
||||
print_verbose,
|
||||
encoding,
|
||||
) -> ModelResponse:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = response.json()
|
||||
except:
|
||||
raise PredibaseError(
|
||||
message=response.text, status_code=response.status_code
|
||||
)
|
||||
if "error" in completion_response:
|
||||
raise PredibaseError(
|
||||
message=str(completion_response["error"]),
|
||||
status_code=response.status_code,
|
||||
)
|
||||
else:
|
||||
if (
|
||||
not isinstance(completion_response, dict)
|
||||
or "generated_text" not in completion_response
|
||||
):
|
||||
raise PredibaseError(
|
||||
status_code=422,
|
||||
message=f"response is not in expected format - {completion_response}",
|
||||
)
|
||||
|
||||
if len(completion_response["generated_text"]) > 0:
|
||||
model_response["choices"][0]["message"]["content"] = self.output_parser(
|
||||
completion_response["generated_text"]
|
||||
)
|
||||
## GETTING LOGPROBS + FINISH REASON
|
||||
if (
|
||||
"details" in completion_response
|
||||
and "tokens" in completion_response["details"]
|
||||
):
|
||||
model_response.choices[0].finish_reason = completion_response[
|
||||
"details"
|
||||
]["finish_reason"]
|
||||
sum_logprob = 0
|
||||
for token in completion_response["details"]["tokens"]:
|
||||
if token["logprob"] != None:
|
||||
sum_logprob += token["logprob"]
|
||||
model_response["choices"][0][
|
||||
"message"
|
||||
]._logprob = (
|
||||
sum_logprob # [TODO] move this to using the actual logprobs
|
||||
)
|
||||
if "best_of" in optional_params and optional_params["best_of"] > 1:
|
||||
if (
|
||||
"details" in completion_response
|
||||
and "best_of_sequences" in completion_response["details"]
|
||||
):
|
||||
choices_list = []
|
||||
for idx, item in enumerate(
|
||||
completion_response["details"]["best_of_sequences"]
|
||||
):
|
||||
sum_logprob = 0
|
||||
for token in item["tokens"]:
|
||||
if token["logprob"] != None:
|
||||
sum_logprob += token["logprob"]
|
||||
if len(item["generated_text"]) > 0:
|
||||
message_obj = Message(
|
||||
content=self.output_parser(item["generated_text"]),
|
||||
logprobs=sum_logprob,
|
||||
)
|
||||
else:
|
||||
message_obj = Message(content=None)
|
||||
choice_obj = Choices(
|
||||
finish_reason=item["finish_reason"],
|
||||
index=idx + 1,
|
||||
message=message_obj,
|
||||
)
|
||||
choices_list.append(choice_obj)
|
||||
model_response["choices"].extend(choices_list)
|
||||
|
||||
## CALCULATING USAGE
|
||||
prompt_tokens = 0
|
||||
try:
|
||||
prompt_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"]["content"])
|
||||
) ##[TODO] use a model-specific tokenizer here
|
||||
except:
|
||||
# this should remain non blocking we should not block a response returning if calculating usage fails
|
||||
pass
|
||||
output_text = model_response["choices"][0]["message"].get("content", "")
|
||||
if output_text is not None and len(output_text) > 0:
|
||||
completion_tokens = 0
|
||||
try:
|
||||
completion_tokens = len(
|
||||
encoding.encode(
|
||||
model_response["choices"][0]["message"].get("content", "")
|
||||
)
|
||||
) ##[TODO] use a model-specific tokenizer
|
||||
except:
|
||||
# this should remain non blocking we should not block a response returning if calculating usage fails
|
||||
pass
|
||||
else:
|
||||
completion_tokens = 0
|
||||
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
model_response.usage = usage # type: ignore
|
||||
return model_response
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key: str,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
tenant_id: str,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers: dict = {},
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
headers = self._validate_environment(api_key, headers)
|
||||
completion_url = ""
|
||||
input_text = ""
|
||||
base_url = "https://serving.app.predibase.com"
|
||||
if "https" in model:
|
||||
completion_url = model
|
||||
elif api_base:
|
||||
base_url = api_base
|
||||
elif "PREDIBASE_API_BASE" in os.environ:
|
||||
base_url = os.getenv("PREDIBASE_API_BASE", "")
|
||||
|
||||
completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"
|
||||
|
||||
if optional_params.get("stream", False) == True:
|
||||
completion_url += "/generate_stream"
|
||||
else:
|
||||
completion_url += "/generate"
|
||||
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages)
|
||||
|
||||
## Load Config
|
||||
config = litellm.PredibaseConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
stream = optional_params.pop("stream", False)
|
||||
|
||||
data = {
|
||||
"inputs": prompt,
|
||||
"parameters": optional_params,
|
||||
}
|
||||
input_text = prompt
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input_text,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"headers": headers,
|
||||
"api_base": completion_url,
|
||||
"acompletion": acompletion,
|
||||
},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
if acompletion is True:
|
||||
### ASYNC STREAMING
|
||||
if stream == True:
|
||||
return self.async_streaming(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=completion_url,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
) # type: ignore
|
||||
else:
|
||||
### ASYNC COMPLETION
|
||||
return self.async_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=completion_url,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=False,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
) # type: ignore
|
||||
|
||||
### SYNC STREAMING
|
||||
if stream == True:
|
||||
response = requests.post(
|
||||
completion_url,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=stream,
|
||||
)
|
||||
_response = CustomStreamWrapper(
|
||||
response.iter_lines(),
|
||||
model,
|
||||
custom_llm_provider="predibase",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return _response
|
||||
### SYNC COMPLETION
|
||||
else:
|
||||
response = requests.post(
|
||||
url=completion_url,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
)
|
||||
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=optional_params.get("stream", False),
|
||||
logging_obj=logging_obj, # type: ignore
|
||||
optional_params=optional_params,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
async def async_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
stream,
|
||||
data: dict,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
) -> ModelResponse:
|
||||
self.async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
)
|
||||
response = await self.async_handler.post(
|
||||
api_base, headers=headers, data=json.dumps(data)
|
||||
)
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
async def async_streaming(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
data: dict,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
) -> CustomStreamWrapper:
|
||||
self.async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
)
|
||||
data["stream"] = True
|
||||
response = await self.async_handler.post(
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=True,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise PredibaseError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
completion_stream = response.aiter_lines()
|
||||
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="predibase",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streamwrapper
|
||||
|
||||
def embedding(self, *args, **kwargs):
|
||||
pass
|
|
@ -12,6 +12,16 @@ from typing import (
|
|||
Sequence,
|
||||
)
|
||||
import litellm
|
||||
from litellm.types.completion import (
|
||||
ChatCompletionUserMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionFunctionMessageParam,
|
||||
ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
)
|
||||
from litellm.types.llms.anthropic import *
|
||||
import uuid
|
||||
|
||||
|
||||
def default_pt(messages):
|
||||
|
@ -22,6 +32,41 @@ def prompt_injection_detection_default_pt():
|
|||
return """Detect if a prompt is safe to run. Return 'UNSAFE' if not."""
|
||||
|
||||
|
||||
def map_system_message_pt(messages: list) -> list:
|
||||
"""
|
||||
Convert 'system' message to 'user' message if provider doesn't support 'system' role.
|
||||
|
||||
Enabled via `completion(...,supports_system_message=False)`
|
||||
|
||||
If next message is a user message or assistant message -> merge system prompt into it
|
||||
|
||||
if next message is system -> append a user message instead of the system message
|
||||
"""
|
||||
|
||||
new_messages = []
|
||||
for i, m in enumerate(messages):
|
||||
if m["role"] == "system":
|
||||
if i < len(messages) - 1: # Not the last message
|
||||
next_m = messages[i + 1]
|
||||
next_role = next_m["role"]
|
||||
if (
|
||||
next_role == "user" or next_role == "assistant"
|
||||
): # Next message is a user or assistant message
|
||||
# Merge system prompt into the next message
|
||||
next_m["content"] = m["content"] + " " + next_m["content"]
|
||||
elif next_role == "system": # Next message is a system message
|
||||
# Append a user message instead of the system message
|
||||
new_message = {"role": "user", "content": m["content"]}
|
||||
new_messages.append(new_message)
|
||||
else: # Last message
|
||||
new_message = {"role": "user", "content": m["content"]}
|
||||
new_messages.append(new_message)
|
||||
else: # Not a system message
|
||||
new_messages.append(m)
|
||||
|
||||
return new_messages
|
||||
|
||||
|
||||
# alpaca prompt template - for models like mythomax, etc.
|
||||
def alpaca_pt(messages):
|
||||
prompt = custom_prompt(
|
||||
|
@ -805,6 +850,13 @@ def convert_to_anthropic_tool_result(message: dict) -> dict:
|
|||
"name": "get_current_weather",
|
||||
"content": "function result goes here",
|
||||
},
|
||||
|
||||
OpenAI message with a function call result looks like:
|
||||
{
|
||||
"role": "function",
|
||||
"name": "get_current_weather",
|
||||
"content": "function result goes here",
|
||||
}
|
||||
"""
|
||||
|
||||
"""
|
||||
|
@ -821,18 +873,42 @@ def convert_to_anthropic_tool_result(message: dict) -> dict:
|
|||
]
|
||||
}
|
||||
"""
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
content = message.get("content")
|
||||
if message["role"] == "tool":
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
content = message.get("content")
|
||||
|
||||
# We can't determine from openai message format whether it's a successful or
|
||||
# error call result so default to the successful result template
|
||||
anthropic_tool_result = {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_call_id,
|
||||
"content": content,
|
||||
}
|
||||
# We can't determine from openai message format whether it's a successful or
|
||||
# error call result so default to the successful result template
|
||||
anthropic_tool_result = {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_call_id,
|
||||
"content": content,
|
||||
}
|
||||
return anthropic_tool_result
|
||||
elif message["role"] == "function":
|
||||
content = message.get("content")
|
||||
anthropic_tool_result = {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": str(uuid.uuid4()),
|
||||
"content": content,
|
||||
}
|
||||
return anthropic_tool_result
|
||||
return {}
|
||||
|
||||
return anthropic_tool_result
|
||||
|
||||
def convert_function_to_anthropic_tool_invoke(function_call):
|
||||
try:
|
||||
anthropic_tool_invoke = [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": get_attribute_or_key(function_call, "name"),
|
||||
"input": json.loads(get_attribute_or_key(function_call, "arguments")),
|
||||
}
|
||||
]
|
||||
return anthropic_tool_invoke
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def convert_to_anthropic_tool_invoke(tool_calls: list) -> list:
|
||||
|
@ -895,7 +971,7 @@ def convert_to_anthropic_tool_invoke(tool_calls: list) -> list:
|
|||
def anthropic_messages_pt(messages: list):
|
||||
"""
|
||||
format messages for anthropic
|
||||
1. Anthropic supports roles like "user" and "assistant", (here litellm translates system-> assistant)
|
||||
1. Anthropic supports roles like "user" and "assistant" (system prompt sent separately)
|
||||
2. The first message always needs to be of role "user"
|
||||
3. Each message must alternate between "user" and "assistant" (this is not addressed as now by litellm)
|
||||
4. final assistant content cannot end with trailing whitespace (anthropic raises an error otherwise)
|
||||
|
@ -903,12 +979,14 @@ def anthropic_messages_pt(messages: list):
|
|||
6. Ensure we only accept role, content. (message.name is not supported)
|
||||
"""
|
||||
# add role=tool support to allow function call result/error submission
|
||||
user_message_types = {"user", "tool"}
|
||||
user_message_types = {"user", "tool", "function"}
|
||||
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, merge them.
|
||||
new_messages = []
|
||||
new_messages: list = []
|
||||
msg_i = 0
|
||||
tool_use_param = False
|
||||
while msg_i < len(messages):
|
||||
user_content = []
|
||||
init_msg_i = msg_i
|
||||
## MERGE CONSECUTIVE USER CONTENT ##
|
||||
while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
|
||||
if isinstance(messages[msg_i]["content"], list):
|
||||
|
@ -924,7 +1002,10 @@ def anthropic_messages_pt(messages: list):
|
|||
)
|
||||
elif m.get("type", "") == "text":
|
||||
user_content.append({"type": "text", "text": m["text"]})
|
||||
elif messages[msg_i]["role"] == "tool":
|
||||
elif (
|
||||
messages[msg_i]["role"] == "tool"
|
||||
or messages[msg_i]["role"] == "function"
|
||||
):
|
||||
# OpenAI's tool message content will always be a string
|
||||
user_content.append(convert_to_anthropic_tool_result(messages[msg_i]))
|
||||
else:
|
||||
|
@ -953,11 +1034,24 @@ def anthropic_messages_pt(messages: list):
|
|||
convert_to_anthropic_tool_invoke(messages[msg_i]["tool_calls"])
|
||||
)
|
||||
|
||||
if messages[msg_i].get("function_call"):
|
||||
assistant_content.extend(
|
||||
convert_function_to_anthropic_tool_invoke(
|
||||
messages[msg_i]["function_call"]
|
||||
)
|
||||
)
|
||||
|
||||
msg_i += 1
|
||||
|
||||
if assistant_content:
|
||||
new_messages.append({"role": "assistant", "content": assistant_content})
|
||||
|
||||
if msg_i == init_msg_i: # prevent infinite loops
|
||||
raise Exception(
|
||||
"Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
|
||||
messages[msg_i]
|
||||
)
|
||||
)
|
||||
if not new_messages or new_messages[0]["role"] != "user":
|
||||
if litellm.modify_params:
|
||||
new_messages.insert(
|
||||
|
@ -969,11 +1063,14 @@ def anthropic_messages_pt(messages: list):
|
|||
)
|
||||
|
||||
if new_messages[-1]["role"] == "assistant":
|
||||
for content in new_messages[-1]["content"]:
|
||||
if isinstance(content, dict) and content["type"] == "text":
|
||||
content["text"] = content[
|
||||
"text"
|
||||
].rstrip() # no trailing whitespace for final assistant message
|
||||
if isinstance(new_messages[-1]["content"], str):
|
||||
new_messages[-1]["content"] = new_messages[-1]["content"].rstrip()
|
||||
elif isinstance(new_messages[-1]["content"], list):
|
||||
for content in new_messages[-1]["content"]:
|
||||
if isinstance(content, dict) and content["type"] == "text":
|
||||
content["text"] = content[
|
||||
"text"
|
||||
].rstrip() # no trailing whitespace for final assistant message
|
||||
|
||||
return new_messages
|
||||
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import os, types
|
||||
import json
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
from litellm.utils import ModelResponse, Usage
|
||||
import litellm
|
||||
import httpx
|
||||
import httpx # type: ignore
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
import os, types, traceback
|
||||
from enum import Enum
|
||||
import json
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional, Any
|
||||
import litellm
|
||||
from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
import httpx
|
||||
import httpx # type: ignore
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
|
||||
|
@ -295,7 +295,7 @@ def completion(
|
|||
EndpointName={model},
|
||||
InferenceComponentName={model_id},
|
||||
ContentType="application/json",
|
||||
Body={data},
|
||||
Body={data}, # type: ignore
|
||||
CustomAttributes="accept_eula=true",
|
||||
)
|
||||
""" # type: ignore
|
||||
|
@ -321,7 +321,7 @@ def completion(
|
|||
response = client.invoke_endpoint(
|
||||
EndpointName={model},
|
||||
ContentType="application/json",
|
||||
Body={data},
|
||||
Body={data}, # type: ignore
|
||||
CustomAttributes="accept_eula=true",
|
||||
)
|
||||
""" # type: ignore
|
||||
|
@ -688,7 +688,7 @@ def embedding(
|
|||
response = client.invoke_endpoint(
|
||||
EndpointName={model},
|
||||
ContentType="application/json",
|
||||
Body={data},
|
||||
Body={data}, # type: ignore
|
||||
CustomAttributes="accept_eula=true",
|
||||
)""" # type: ignore
|
||||
logging_obj.pre_call(
|
||||
|
|
|
@ -6,11 +6,11 @@ Reference: https://docs.together.ai/docs/openai-api-compatibility
|
|||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
import litellm
|
||||
import httpx
|
||||
import httpx # type: ignore
|
||||
from litellm.utils import ModelResponse, Usage
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
||||
|
|
119
litellm/llms/triton.py
Normal file
119
litellm/llms/triton.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests, copy # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional, List
|
||||
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
|
||||
import litellm
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from .base import BaseLLM
|
||||
import httpx # type: ignore
|
||||
|
||||
|
||||
class TritonError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(
|
||||
method="POST",
|
||||
url="https://api.anthropic.com/v1/messages", # using anthropic api base since httpx requires a url
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class TritonChatCompletion(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
async def aembedding(
|
||||
self,
|
||||
data: dict,
|
||||
model_response: litellm.utils.EmbeddingResponse,
|
||||
api_base: str,
|
||||
logging_obj=None,
|
||||
api_key: Optional[str] = None,
|
||||
):
|
||||
|
||||
async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
)
|
||||
|
||||
response = await async_handler.post(url=api_base, data=json.dumps(data))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise TritonError(status_code=response.status_code, message=response.text)
|
||||
|
||||
_text_response = response.text
|
||||
|
||||
logging_obj.post_call(original_response=_text_response)
|
||||
|
||||
_json_response = response.json()
|
||||
|
||||
_outputs = _json_response["outputs"]
|
||||
_output_data = _outputs[0]["data"]
|
||||
_embedding_output = {
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embedding": _output_data,
|
||||
}
|
||||
|
||||
model_response.model = _json_response.get("model_name", "None")
|
||||
model_response.data = [_embedding_output]
|
||||
|
||||
return model_response
|
||||
|
||||
def embedding(
|
||||
self,
|
||||
model: str,
|
||||
input: list,
|
||||
timeout: float,
|
||||
api_base: str,
|
||||
model_response: litellm.utils.EmbeddingResponse,
|
||||
api_key: Optional[str] = None,
|
||||
logging_obj=None,
|
||||
optional_params=None,
|
||||
client=None,
|
||||
aembedding=None,
|
||||
):
|
||||
data_for_triton = {
|
||||
"inputs": [
|
||||
{
|
||||
"name": "input_text",
|
||||
"shape": [1],
|
||||
"datatype": "BYTES",
|
||||
"data": input,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
|
||||
curl_string = f"curl {api_base} -X POST -H 'Content-Type: application/json' -d '{data_for_triton}'"
|
||||
|
||||
logging_obj.pre_call(
|
||||
input="",
|
||||
api_key=None,
|
||||
additional_args={
|
||||
"complete_input_dict": optional_params,
|
||||
"request_str": curl_string,
|
||||
},
|
||||
)
|
||||
|
||||
if aembedding == True:
|
||||
response = self.aembedding(
|
||||
data=data_for_triton,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
)
|
||||
return response
|
||||
else:
|
||||
raise Exception(
|
||||
"Only async embedding supported for triton, please use litellm.aembedding() for now"
|
||||
)
|
|
@ -1,12 +1,12 @@
|
|||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests
|
||||
import requests # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional, Union, List
|
||||
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
|
||||
import litellm, uuid
|
||||
import httpx, inspect
|
||||
import httpx, inspect # type: ignore
|
||||
|
||||
|
||||
class VertexAIError(Exception):
|
||||
|
@ -419,6 +419,7 @@ def completion(
|
|||
from google.protobuf.struct_pb2 import Value # type: ignore
|
||||
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore
|
||||
import google.auth # type: ignore
|
||||
import proto # type: ignore
|
||||
|
||||
## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
|
||||
print_verbose(
|
||||
|
@ -605,9 +606,21 @@ def completion(
|
|||
):
|
||||
function_call = response.candidates[0].content.parts[0].function_call
|
||||
args_dict = {}
|
||||
for k, v in function_call.args.items():
|
||||
args_dict[k] = v
|
||||
args_str = json.dumps(args_dict)
|
||||
|
||||
# Check if it's a RepeatedComposite instance
|
||||
for key, val in function_call.args.items():
|
||||
if isinstance(
|
||||
val, proto.marshal.collections.repeated.RepeatedComposite
|
||||
):
|
||||
# If so, convert to list
|
||||
args_dict[key] = [v for v in val]
|
||||
else:
|
||||
args_dict[key] = val
|
||||
|
||||
try:
|
||||
args_str = json.dumps(args_dict)
|
||||
except Exception as e:
|
||||
raise VertexAIError(status_code=422, message=str(e))
|
||||
message = litellm.Message(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
|
@ -810,6 +823,8 @@ def completion(
|
|||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
except Exception as e:
|
||||
if isinstance(e, VertexAIError):
|
||||
raise e
|
||||
raise VertexAIError(status_code=500, message=str(e))
|
||||
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests, copy
|
||||
import requests, copy # type: ignore
|
||||
import time, uuid
|
||||
from typing import Callable, Optional, List
|
||||
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
|
||||
|
@ -17,7 +17,7 @@ from .prompt_templates.factory import (
|
|||
extract_between_tags,
|
||||
parse_xml_params,
|
||||
)
|
||||
import httpx
|
||||
import httpx # type: ignore
|
||||
|
||||
|
||||
class VertexAIError(Exception):
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import os
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests
|
||||
import time, httpx
|
||||
import requests # type: ignore
|
||||
import time, httpx # type: ignore
|
||||
from typing import Callable, Any
|
||||
from litellm.utils import ModelResponse, Usage
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
|
|
@ -3,8 +3,8 @@ import json, types, time # noqa: E401
|
|||
from contextlib import contextmanager
|
||||
from typing import Callable, Dict, Optional, Any, Union, List
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
import litellm
|
||||
from litellm.utils import ModelResponse, get_secret, Usage
|
||||
|
||||
|
|
180
litellm/main.py
180
litellm/main.py
|
@ -14,7 +14,6 @@ import dotenv, traceback, random, asyncio, time, contextvars
|
|||
from copy import deepcopy
|
||||
import httpx
|
||||
import litellm
|
||||
|
||||
from ._logging import verbose_logger
|
||||
from litellm import ( # type: ignore
|
||||
client,
|
||||
|
@ -34,9 +33,12 @@ from litellm.utils import (
|
|||
async_mock_completion_streaming_obj,
|
||||
convert_to_model_response_object,
|
||||
token_counter,
|
||||
create_pretrained_tokenizer,
|
||||
create_tokenizer,
|
||||
Usage,
|
||||
get_optional_params_embeddings,
|
||||
get_optional_params_image_gen,
|
||||
supports_httpx_timeout,
|
||||
)
|
||||
from .llms import (
|
||||
anthropic_text,
|
||||
|
@ -44,6 +46,7 @@ from .llms import (
|
|||
ai21,
|
||||
sagemaker,
|
||||
bedrock,
|
||||
triton,
|
||||
huggingface_restapi,
|
||||
replicate,
|
||||
aleph_alpha,
|
||||
|
@ -71,10 +74,13 @@ from .llms.azure_text import AzureTextCompletion
|
|||
from .llms.anthropic import AnthropicChatCompletion
|
||||
from .llms.anthropic_text import AnthropicTextCompletion
|
||||
from .llms.huggingface_restapi import Huggingface
|
||||
from .llms.predibase import PredibaseChatCompletion
|
||||
from .llms.triton import TritonChatCompletion
|
||||
from .llms.prompt_templates.factory import (
|
||||
prompt_factory,
|
||||
custom_prompt,
|
||||
function_call_prompt,
|
||||
map_system_message_pt,
|
||||
)
|
||||
import tiktoken
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
@ -106,6 +112,8 @@ anthropic_text_completions = AnthropicTextCompletion()
|
|||
azure_chat_completions = AzureChatCompletion()
|
||||
azure_text_completions = AzureTextCompletion()
|
||||
huggingface = Huggingface()
|
||||
predibase_chat_completions = PredibaseChatCompletion()
|
||||
triton_chat_completions = TritonChatCompletion()
|
||||
####### COMPLETION ENDPOINTS ################
|
||||
|
||||
|
||||
|
@ -184,6 +192,7 @@ async def acompletion(
|
|||
top_p: Optional[float] = None,
|
||||
n: Optional[int] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[dict] = None,
|
||||
stop=None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
|
@ -203,6 +212,7 @@ async def acompletion(
|
|||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
model_list: Optional[list] = None, # pass in a list of api_base,keys, etc.
|
||||
extra_headers: Optional[dict] = None,
|
||||
# Optional liteLLM function params
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -220,6 +230,7 @@ async def acompletion(
|
|||
top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0).
|
||||
n (int, optional): The number of completions to generate (default is 1).
|
||||
stream (bool, optional): If True, return a streaming response (default is False).
|
||||
stream_options (dict, optional): A dictionary containing options for the streaming response. Only use this if stream is True.
|
||||
stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens.
|
||||
max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
|
||||
presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.
|
||||
|
@ -257,6 +268,7 @@ async def acompletion(
|
|||
"top_p": top_p,
|
||||
"n": n,
|
||||
"stream": stream,
|
||||
"stream_options": stream_options,
|
||||
"stop": stop,
|
||||
"max_tokens": max_tokens,
|
||||
"presence_penalty": presence_penalty,
|
||||
|
@ -301,6 +313,7 @@ async def acompletion(
|
|||
or custom_llm_provider == "deepinfra"
|
||||
or custom_llm_provider == "perplexity"
|
||||
or custom_llm_provider == "groq"
|
||||
or custom_llm_provider == "deepseek"
|
||||
or custom_llm_provider == "text-completion-openai"
|
||||
or custom_llm_provider == "huggingface"
|
||||
or custom_llm_provider == "ollama"
|
||||
|
@ -309,6 +322,7 @@ async def acompletion(
|
|||
or custom_llm_provider == "gemini"
|
||||
or custom_llm_provider == "sagemaker"
|
||||
or custom_llm_provider == "anthropic"
|
||||
or custom_llm_provider == "predibase"
|
||||
or custom_llm_provider in litellm.openai_compatible_providers
|
||||
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
@ -448,11 +462,12 @@ def completion(
|
|||
model: str,
|
||||
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
|
||||
messages: List = [],
|
||||
timeout: Optional[Union[float, int]] = None,
|
||||
timeout: Optional[Union[float, str, httpx.Timeout]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
n: Optional[int] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[dict] = None,
|
||||
stop=None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
|
@ -492,6 +507,7 @@ def completion(
|
|||
top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0).
|
||||
n (int, optional): The number of completions to generate (default is 1).
|
||||
stream (bool, optional): If True, return a streaming response (default is False).
|
||||
stream_options (dict, optional): A dictionary containing options for the streaming response. Only set this when you set stream: true.
|
||||
stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens.
|
||||
max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
|
||||
presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.
|
||||
|
@ -551,6 +567,7 @@ def completion(
|
|||
eos_token = kwargs.get("eos_token", None)
|
||||
preset_cache_key = kwargs.get("preset_cache_key", None)
|
||||
hf_model_name = kwargs.get("hf_model_name", None)
|
||||
supports_system_message = kwargs.get("supports_system_message", None)
|
||||
### TEXT COMPLETION CALLS ###
|
||||
text_completion = kwargs.get("text_completion", False)
|
||||
atext_completion = kwargs.get("atext_completion", False)
|
||||
|
@ -568,6 +585,7 @@ def completion(
|
|||
"top_p",
|
||||
"n",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"stop",
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
|
@ -616,6 +634,7 @@ def completion(
|
|||
"model_list",
|
||||
"num_retries",
|
||||
"context_window_fallback_dict",
|
||||
"retry_policy",
|
||||
"roles",
|
||||
"final_prompt_value",
|
||||
"bos_token",
|
||||
|
@ -641,16 +660,30 @@ def completion(
|
|||
"no-log",
|
||||
"base_model",
|
||||
"stream_timeout",
|
||||
"supports_system_message",
|
||||
"region_name",
|
||||
"allowed_model_region",
|
||||
]
|
||||
|
||||
default_params = openai_params + litellm_params
|
||||
non_default_params = {
|
||||
k: v for k, v in kwargs.items() if k not in default_params
|
||||
} # model-specific params - pass them straight to the model/provider
|
||||
if timeout is None:
|
||||
timeout = (
|
||||
kwargs.get("request_timeout", None) or 600
|
||||
) # set timeout for 10 minutes by default
|
||||
timeout = float(timeout)
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
|
||||
try:
|
||||
if base_url is not None:
|
||||
api_base = base_url
|
||||
|
@ -745,6 +778,13 @@ def completion(
|
|||
custom_prompt_dict[model]["bos_token"] = bos_token
|
||||
if eos_token:
|
||||
custom_prompt_dict[model]["eos_token"] = eos_token
|
||||
|
||||
if (
|
||||
supports_system_message is not None
|
||||
and isinstance(supports_system_message, bool)
|
||||
and supports_system_message == False
|
||||
):
|
||||
messages = map_system_message_pt(messages=messages)
|
||||
model_api_key = get_api_key(
|
||||
llm_provider=custom_llm_provider, dynamic_api_key=api_key
|
||||
) # get the api key from the environment if required for the model
|
||||
|
@ -759,6 +799,7 @@ def completion(
|
|||
top_p=top_p,
|
||||
n=n,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
stop=stop,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
|
@ -871,7 +912,7 @@ def completion(
|
|||
logger_fn=logger_fn,
|
||||
logging_obj=logging,
|
||||
acompletion=acompletion,
|
||||
timeout=timeout,
|
||||
timeout=timeout, # type: ignore
|
||||
client=client, # pass AsyncAzureOpenAI, AzureOpenAI client
|
||||
)
|
||||
|
||||
|
@ -958,6 +999,7 @@ def completion(
|
|||
or custom_llm_provider == "deepinfra"
|
||||
or custom_llm_provider == "perplexity"
|
||||
or custom_llm_provider == "groq"
|
||||
or custom_llm_provider == "deepseek"
|
||||
or custom_llm_provider == "anyscale"
|
||||
or custom_llm_provider == "mistral"
|
||||
or custom_llm_provider == "openai"
|
||||
|
@ -1012,7 +1054,7 @@ def completion(
|
|||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
timeout=timeout,
|
||||
timeout=timeout, # type: ignore
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
client=client, # pass AsyncOpenAI, OpenAI client
|
||||
organization=organization,
|
||||
|
@ -1097,7 +1139,7 @@ def completion(
|
|||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
timeout=timeout,
|
||||
timeout=timeout, # type: ignore
|
||||
)
|
||||
|
||||
if (
|
||||
|
@ -1471,7 +1513,7 @@ def completion(
|
|||
acompletion=acompletion,
|
||||
logging_obj=logging,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
timeout=timeout,
|
||||
timeout=timeout, # type: ignore
|
||||
)
|
||||
if (
|
||||
"stream" in optional_params
|
||||
|
@ -1564,7 +1606,7 @@ def completion(
|
|||
logger_fn=logger_fn,
|
||||
logging_obj=logging,
|
||||
acompletion=acompletion,
|
||||
timeout=timeout,
|
||||
timeout=timeout, # type: ignore
|
||||
)
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
|
@ -1749,6 +1791,52 @@ def completion(
|
|||
)
|
||||
return response
|
||||
response = model_response
|
||||
elif custom_llm_provider == "predibase":
|
||||
tenant_id = (
|
||||
optional_params.pop("tenant_id", None)
|
||||
or optional_params.pop("predibase_tenant_id", None)
|
||||
or litellm.predibase_tenant_id
|
||||
or get_secret("PREDIBASE_TENANT_ID")
|
||||
)
|
||||
|
||||
api_base = (
|
||||
optional_params.pop("api_base", None)
|
||||
or optional_params.pop("base_url", None)
|
||||
or litellm.api_base
|
||||
or get_secret("PREDIBASE_API_BASE")
|
||||
)
|
||||
|
||||
api_key = (
|
||||
api_key
|
||||
or litellm.api_key
|
||||
or litellm.predibase_key
|
||||
or get_secret("PREDIBASE_API_KEY")
|
||||
)
|
||||
|
||||
_model_response = predibase_chat_completions.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
acompletion=acompletion,
|
||||
api_base=api_base,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
api_key=api_key,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
if (
|
||||
"stream" in optional_params
|
||||
and optional_params["stream"] == True
|
||||
and acompletion == False
|
||||
):
|
||||
return _model_response
|
||||
response = _model_response
|
||||
elif custom_llm_provider == "ai21":
|
||||
custom_llm_provider = "ai21"
|
||||
ai21_key = (
|
||||
|
@ -1844,6 +1932,7 @@ def completion(
|
|||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
@ -1891,7 +1980,7 @@ def completion(
|
|||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
timeout=timeout,
|
||||
timeout=timeout, # type: ignore
|
||||
)
|
||||
if (
|
||||
"stream" in optional_params
|
||||
|
@ -2143,7 +2232,7 @@ def completion(
|
|||
"""
|
||||
assume input to custom LLM api bases follow this format:
|
||||
resp = requests.post(
|
||||
api_base,
|
||||
api_base,
|
||||
json={
|
||||
'model': 'meta-llama/Llama-2-13b-hf', # model name
|
||||
'params': {
|
||||
|
@ -2272,7 +2361,7 @@ def batch_completion(
|
|||
n: Optional[int] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stop=None,
|
||||
max_tokens: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[dict] = None,
|
||||
|
@ -2535,11 +2624,13 @@ async def aembedding(*args, **kwargs):
|
|||
or custom_llm_provider == "voyage"
|
||||
or custom_llm_provider == "mistral"
|
||||
or custom_llm_provider == "custom_openai"
|
||||
or custom_llm_provider == "triton"
|
||||
or custom_llm_provider == "anyscale"
|
||||
or custom_llm_provider == "openrouter"
|
||||
or custom_llm_provider == "deepinfra"
|
||||
or custom_llm_provider == "perplexity"
|
||||
or custom_llm_provider == "groq"
|
||||
or custom_llm_provider == "deepseek"
|
||||
or custom_llm_provider == "fireworks_ai"
|
||||
or custom_llm_provider == "ollama"
|
||||
or custom_llm_provider == "vertex_ai"
|
||||
|
@ -2665,6 +2756,7 @@ def embedding(
|
|||
"model_list",
|
||||
"num_retries",
|
||||
"context_window_fallback_dict",
|
||||
"retry_policy",
|
||||
"roles",
|
||||
"final_prompt_value",
|
||||
"bos_token",
|
||||
|
@ -2688,6 +2780,8 @@ def embedding(
|
|||
"ttl",
|
||||
"cache",
|
||||
"no-log",
|
||||
"region_name",
|
||||
"allowed_model_region",
|
||||
]
|
||||
default_params = openai_params + litellm_params
|
||||
non_default_params = {
|
||||
|
@ -2864,23 +2958,43 @@ def embedding(
|
|||
optional_params=optional_params,
|
||||
model_response=EmbeddingResponse(),
|
||||
)
|
||||
elif custom_llm_provider == "triton":
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
"api_base is required for triton. Please pass `api_base`"
|
||||
)
|
||||
response = triton_chat_completions.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
logging_obj=logging,
|
||||
timeout=timeout,
|
||||
model_response=EmbeddingResponse(),
|
||||
optional_params=optional_params,
|
||||
client=client,
|
||||
aembedding=aembedding,
|
||||
)
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
vertex_ai_project = (
|
||||
optional_params.pop("vertex_project", None)
|
||||
or optional_params.pop("vertex_ai_project", None)
|
||||
or litellm.vertex_project
|
||||
or get_secret("VERTEXAI_PROJECT")
|
||||
or get_secret("VERTEX_PROJECT")
|
||||
)
|
||||
vertex_ai_location = (
|
||||
optional_params.pop("vertex_location", None)
|
||||
or optional_params.pop("vertex_ai_location", None)
|
||||
or litellm.vertex_location
|
||||
or get_secret("VERTEXAI_LOCATION")
|
||||
or get_secret("VERTEX_LOCATION")
|
||||
)
|
||||
vertex_credentials = (
|
||||
optional_params.pop("vertex_credentials", None)
|
||||
or optional_params.pop("vertex_ai_credentials", None)
|
||||
or get_secret("VERTEXAI_CREDENTIALS")
|
||||
or get_secret("VERTEX_CREDENTIALS")
|
||||
)
|
||||
|
||||
response = vertex_ai.embedding(
|
||||
|
@ -2921,16 +3035,18 @@ def embedding(
|
|||
model=model, # type: ignore
|
||||
llm_provider="ollama", # type: ignore
|
||||
)
|
||||
if aembedding:
|
||||
response = ollama.ollama_aembeddings(
|
||||
api_base=api_base,
|
||||
model=model,
|
||||
prompts=input,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
model_response=EmbeddingResponse(),
|
||||
)
|
||||
ollama_embeddings_fn = (
|
||||
ollama.ollama_aembeddings if aembedding else ollama.ollama_embeddings
|
||||
)
|
||||
response = ollama_embeddings_fn(
|
||||
api_base=api_base,
|
||||
model=model,
|
||||
prompts=input,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
model_response=EmbeddingResponse(),
|
||||
)
|
||||
elif custom_llm_provider == "sagemaker":
|
||||
response = sagemaker.embedding(
|
||||
model=model,
|
||||
|
@ -3059,11 +3175,13 @@ async def atext_completion(*args, **kwargs):
|
|||
or custom_llm_provider == "deepinfra"
|
||||
or custom_llm_provider == "perplexity"
|
||||
or custom_llm_provider == "groq"
|
||||
or custom_llm_provider == "deepseek"
|
||||
or custom_llm_provider == "fireworks_ai"
|
||||
or custom_llm_provider == "text-completion-openai"
|
||||
or custom_llm_provider == "huggingface"
|
||||
or custom_llm_provider == "ollama"
|
||||
or custom_llm_provider == "vertex_ai"
|
||||
or custom_llm_provider in litellm.openai_compatible_providers
|
||||
): # currently implemented aiohttp calls for just azure and openai, soon all.
|
||||
# Await normally
|
||||
response = await loop.run_in_executor(None, func_with_context)
|
||||
|
@ -3094,6 +3212,8 @@ async def atext_completion(*args, **kwargs):
|
|||
## TRANSLATE CHAT TO TEXT FORMAT ##
|
||||
if isinstance(response, TextCompletionResponse):
|
||||
return response
|
||||
elif asyncio.iscoroutine(response):
|
||||
response = await response
|
||||
|
||||
text_completion_response = TextCompletionResponse()
|
||||
text_completion_response["id"] = response.get("id", None)
|
||||
|
@ -3153,6 +3273,7 @@ def text_completion(
|
|||
Union[str, List[str]]
|
||||
] = None, # Optional: Sequences where the API will stop generating further tokens.
|
||||
stream: Optional[bool] = None, # Optional: Whether to stream back partial progress.
|
||||
stream_options: Optional[dict] = None,
|
||||
suffix: Optional[
|
||||
str
|
||||
] = None, # Optional: The suffix that comes after a completion of inserted text.
|
||||
|
@ -3230,6 +3351,8 @@ def text_completion(
|
|||
optional_params["stop"] = stop
|
||||
if stream is not None:
|
||||
optional_params["stream"] = stream
|
||||
if stream_options is not None:
|
||||
optional_params["stream_options"] = stream_options
|
||||
if suffix is not None:
|
||||
optional_params["suffix"] = suffix
|
||||
if temperature is not None:
|
||||
|
@ -3340,7 +3463,9 @@ def text_completion(
|
|||
if kwargs.get("acompletion", False) == True:
|
||||
return response
|
||||
if stream == True or kwargs.get("stream", False) == True:
|
||||
response = TextCompletionStreamWrapper(completion_stream=response, model=model)
|
||||
response = TextCompletionStreamWrapper(
|
||||
completion_stream=response, model=model, stream_options=stream_options
|
||||
)
|
||||
return response
|
||||
transformed_logprobs = None
|
||||
# only supported for TGI models
|
||||
|
@ -3534,6 +3659,7 @@ def image_generation(
|
|||
"model_list",
|
||||
"num_retries",
|
||||
"context_window_fallback_dict",
|
||||
"retry_policy",
|
||||
"roles",
|
||||
"final_prompt_value",
|
||||
"bos_token",
|
||||
|
@ -3554,6 +3680,8 @@ def image_generation(
|
|||
"caching_groups",
|
||||
"ttl",
|
||||
"cache",
|
||||
"region_name",
|
||||
"allowed_model_region",
|
||||
]
|
||||
default_params = openai_params + litellm_params
|
||||
non_default_params = {
|
||||
|
|
|
@ -338,6 +338,18 @@
|
|||
"output_cost_per_second": 0.0001,
|
||||
"litellm_provider": "azure"
|
||||
},
|
||||
"azure/gpt-4-turbo-2024-04-09": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.00001,
|
||||
"output_cost_per_token": 0.00003,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_parallel_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"azure/gpt-4-0125-preview": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 128000,
|
||||
|
@ -727,6 +739,24 @@
|
|||
"litellm_provider": "mistral",
|
||||
"mode": "embedding"
|
||||
},
|
||||
"deepseek-chat": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 32000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000014,
|
||||
"output_cost_per_token": 0.00000028,
|
||||
"litellm_provider": "deepseek",
|
||||
"mode": "chat"
|
||||
},
|
||||
"deepseek-coder": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 16000,
|
||||
"max_output_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000014,
|
||||
"output_cost_per_token": 0.00000028,
|
||||
"litellm_provider": "deepseek",
|
||||
"mode": "chat"
|
||||
},
|
||||
"groq/llama2-70b-4096": {
|
||||
"max_tokens": 4096,
|
||||
"max_input_tokens": 4096,
|
||||
|
@ -813,6 +843,7 @@
|
|||
"litellm_provider": "anthropic",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"tool_use_system_prompt_tokens": 264
|
||||
},
|
||||
"claude-3-opus-20240229": {
|
||||
|
@ -824,6 +855,7 @@
|
|||
"litellm_provider": "anthropic",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"tool_use_system_prompt_tokens": 395
|
||||
},
|
||||
"claude-3-sonnet-20240229": {
|
||||
|
@ -835,6 +867,7 @@
|
|||
"litellm_provider": "anthropic",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"tool_use_system_prompt_tokens": 159
|
||||
},
|
||||
"text-bison": {
|
||||
|
@ -1045,8 +1078,8 @@
|
|||
"max_tokens": 8192,
|
||||
"max_input_tokens": 1000000,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0,
|
||||
"output_cost_per_token": 0,
|
||||
"input_cost_per_token": 0.000000625,
|
||||
"output_cost_per_token": 0.000001875,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
|
@ -1057,8 +1090,8 @@
|
|||
"max_tokens": 8192,
|
||||
"max_input_tokens": 1000000,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0,
|
||||
"output_cost_per_token": 0,
|
||||
"input_cost_per_token": 0.000000625,
|
||||
"output_cost_per_token": 0.000001875,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
|
@ -1069,8 +1102,8 @@
|
|||
"max_tokens": 8192,
|
||||
"max_input_tokens": 1000000,
|
||||
"max_output_tokens": 8192,
|
||||
"input_cost_per_token": 0,
|
||||
"output_cost_per_token": 0,
|
||||
"input_cost_per_token": 0.000000625,
|
||||
"output_cost_per_token": 0.000001875,
|
||||
"litellm_provider": "vertex_ai-language-models",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
|
@ -1142,7 +1175,8 @@
|
|||
"output_cost_per_token": 0.000015,
|
||||
"litellm_provider": "vertex_ai-anthropic_models",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"vertex_ai/claude-3-haiku@20240307": {
|
||||
"max_tokens": 4096,
|
||||
|
@ -1152,7 +1186,8 @@
|
|||
"output_cost_per_token": 0.00000125,
|
||||
"litellm_provider": "vertex_ai-anthropic_models",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"vertex_ai/claude-3-opus@20240229": {
|
||||
"max_tokens": 4096,
|
||||
|
@ -1162,7 +1197,8 @@
|
|||
"output_cost_per_token": 0.0000075,
|
||||
"litellm_provider": "vertex_ai-anthropic_models",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"textembedding-gecko": {
|
||||
"max_tokens": 3072,
|
||||
|
@ -1581,6 +1617,7 @@
|
|||
"litellm_provider": "openrouter",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"tool_use_system_prompt_tokens": 395
|
||||
},
|
||||
"openrouter/google/palm-2-chat-bison": {
|
||||
|
@ -1813,6 +1850,15 @@
|
|||
"litellm_provider": "bedrock",
|
||||
"mode": "embedding"
|
||||
},
|
||||
"amazon.titan-embed-text-v2:0": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 8192,
|
||||
"output_vector_size": 1024,
|
||||
"input_cost_per_token": 0.0000002,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "embedding"
|
||||
},
|
||||
"mistral.mistral-7b-instruct-v0:2": {
|
||||
"max_tokens": 8191,
|
||||
"max_input_tokens": 32000,
|
||||
|
@ -1929,7 +1975,8 @@
|
|||
"output_cost_per_token": 0.000015,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": {
|
||||
"max_tokens": 4096,
|
||||
|
@ -1939,7 +1986,8 @@
|
|||
"output_cost_per_token": 0.00000125,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"anthropic.claude-3-opus-20240229-v1:0": {
|
||||
"max_tokens": 4096,
|
||||
|
@ -1949,7 +1997,8 @@
|
|||
"output_cost_per_token": 0.000075,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"anthropic.claude-v1": {
|
||||
"max_tokens": 8191,
|
||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -0,0 +1 @@
|
|||
(self.webpackChunk_N_E=self.webpackChunk_N_E||[]).push([[185],{87421:function(n,e,t){Promise.resolve().then(t.t.bind(t,99646,23)),Promise.resolve().then(t.t.bind(t,63385,23))},63385:function(){},99646:function(n){n.exports={style:{fontFamily:"'__Inter_c23dc8', '__Inter_Fallback_c23dc8'",fontStyle:"normal"},className:"__className_c23dc8"}}},function(n){n.O(0,[971,69,744],function(){return n(n.s=87421)}),_N_E=n.O()}]);
|
|
@ -1 +0,0 @@
|
|||
(self.webpackChunk_N_E=self.webpackChunk_N_E||[]).push([[185],{93553:function(n,e,t){Promise.resolve().then(t.t.bind(t,63385,23)),Promise.resolve().then(t.t.bind(t,99646,23))},63385:function(){},99646:function(n){n.exports={style:{fontFamily:"'__Inter_12bbc4', '__Inter_Fallback_12bbc4'",fontStyle:"normal"},className:"__className_12bbc4"}}},function(n){n.O(0,[971,69,744],function(){return n(n.s=93553)}),_N_E=n.O()}]);
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue