forked from phoenix/litellm-mirror
Merge branch 'BerriAI:main' into feature_allow_claude_prefill
This commit is contained in:
commit
53e5e1df07
42 changed files with 1395 additions and 490 deletions
|
@ -134,11 +134,15 @@ jobs:
|
|||
- run:
|
||||
name: Trigger Github Action for new Docker Container
|
||||
command: |
|
||||
echo "Install TOML package."
|
||||
python3 -m pip install toml
|
||||
VERSION=$(python3 -c "import toml; print(toml.load('pyproject.toml')['tool']['poetry']['version'])")
|
||||
echo "LiteLLM Version ${VERSION}"
|
||||
curl -X POST \
|
||||
-H "Accept: application/vnd.github.v3+json" \
|
||||
-H "Authorization: Bearer $GITHUB_TOKEN" \
|
||||
"https://api.github.com/repos/BerriAI/litellm/actions/workflows/ghcr_deploy.yml/dispatches" \
|
||||
-d '{"ref":"main"}'
|
||||
-d "{\"ref\":\"main\", \"inputs\":{\"tag\":\"${VERSION}\"}}"
|
||||
|
||||
workflows:
|
||||
version: 2
|
||||
|
|
43
.github/workflows/ghcr_deploy.yml
vendored
43
.github/workflows/ghcr_deploy.yml
vendored
|
@ -1,12 +1,10 @@
|
|||
#
|
||||
name: Build, Publish LiteLLM Docker Image
|
||||
# this workflow is triggered by an API call when there is a new PyPI release of LiteLLM
|
||||
name: Build, Publish LiteLLM Docker Image. New Release
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: "The tag version you want to build"
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
# Defines two custom environment variables for the workflow. Used for the Container registry domain, and a name for the Docker image that this workflow builds.
|
||||
env:
|
||||
|
@ -46,7 +44,7 @@ jobs:
|
|||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || github.event.release.tag_name || 'latest' }} # if a tag is provided, use that, otherwise use the release tag, and if neither is available, use 'latest'
|
||||
tags: ${{ steps.meta.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }} # if a tag is provided, use that, otherwise use the release tag, and if neither is available, use 'latest'
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
build-and-push-image-alpine:
|
||||
runs-on: ubuntu-latest
|
||||
|
@ -76,5 +74,38 @@ jobs:
|
|||
context: .
|
||||
dockerfile: Dockerfile.alpine
|
||||
push: true
|
||||
tags: ${{ steps.meta-alpine.outputs.tags }}-${{ github.event.inputs.tag || github.event.release.tag_name || 'latest' }}
|
||||
tags: ${{ steps.meta-alpine.outputs.tags }}-${{ github.event.inputs.tag || 'latest' }}
|
||||
labels: ${{ steps.meta-alpine.outputs.labels }}
|
||||
release:
|
||||
name: "New LiteLLM Release"
|
||||
|
||||
runs-on: "ubuntu-latest"
|
||||
|
||||
steps:
|
||||
- name: Display version
|
||||
run: echo "Current version is ${{ github.event.inputs.tag }}"
|
||||
- name: "Set Release Tag"
|
||||
run: echo "RELEASE_TAG=${{ github.event.inputs.tag }}" >> $GITHUB_ENV
|
||||
- name: Display release tag
|
||||
run: echo "RELEASE_TAG is $RELEASE_TAG"
|
||||
- name: "Create release"
|
||||
uses: "actions/github-script@v6"
|
||||
with:
|
||||
github-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
script: |
|
||||
try {
|
||||
const response = await github.rest.repos.createRelease({
|
||||
draft: false,
|
||||
generate_release_notes: true,
|
||||
name: process.env.RELEASE_TAG,
|
||||
owner: context.repo.owner,
|
||||
prerelease: false,
|
||||
repo: context.repo.repo,
|
||||
tag_name: process.env.RELEASE_TAG,
|
||||
});
|
||||
|
||||
core.exportVariable('RELEASE_ID', response.data.id);
|
||||
core.exportVariable('RELEASE_UPLOAD_URL', response.data.upload_url);
|
||||
} catch (error) {
|
||||
core.setFailed(error.message);
|
||||
}
|
||||
|
|
31
.github/workflows/read_pyproject_version.yml
vendored
Normal file
31
.github/workflows/read_pyproject_version.yml
vendored
Normal file
|
@ -0,0 +1,31 @@
|
|||
name: Read Version from pyproject.toml
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main # Change this to the default branch of your repository
|
||||
|
||||
jobs:
|
||||
read-version:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.8 # Adjust the Python version as needed
|
||||
|
||||
- name: Install dependencies
|
||||
run: pip install toml
|
||||
|
||||
- name: Read version from pyproject.toml
|
||||
id: read-version
|
||||
run: |
|
||||
version=$(python -c 'import toml; print(toml.load("pyproject.toml")["tool"]["commitizen"]["version"])')
|
||||
printf "LITELLM_VERSION=%s" "$version" >> $GITHUB_ENV
|
||||
|
||||
- name: Display version
|
||||
run: echo "Current version is $LITELLM_VERSION"
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -31,3 +31,4 @@ proxy_server_config_@.yaml
|
|||
.gitignore
|
||||
proxy_server_config_2.yaml
|
||||
litellm/proxy/secret_managers/credentials.json
|
||||
hosted_config.yaml
|
||||
|
|
13
Dockerfile
13
Dockerfile
|
@ -3,7 +3,6 @@ ARG LITELLM_BUILD_IMAGE=python:3.9
|
|||
|
||||
# Runtime image
|
||||
ARG LITELLM_RUNTIME_IMAGE=python:3.9-slim
|
||||
|
||||
# Builder stage
|
||||
FROM $LITELLM_BUILD_IMAGE as builder
|
||||
|
||||
|
@ -35,8 +34,12 @@ RUN pip wheel --no-cache-dir --wheel-dir=/wheels/ -r requirements.txt
|
|||
|
||||
# Runtime stage
|
||||
FROM $LITELLM_RUNTIME_IMAGE as runtime
|
||||
ARG with_database
|
||||
|
||||
WORKDIR /app
|
||||
# Copy the current directory contents into the container at /app
|
||||
COPY . .
|
||||
RUN ls -la /app
|
||||
|
||||
# Copy the built wheel from the builder stage to the runtime stage; assumes only one wheel file is present
|
||||
COPY --from=builder /app/dist/*.whl .
|
||||
|
@ -45,6 +48,14 @@ COPY --from=builder /wheels/ /wheels/
|
|||
# Install the built wheel using pip; again using a wildcard if it's the only file
|
||||
RUN pip install *.whl /wheels/* --no-index --find-links=/wheels/ && rm -f *.whl && rm -rf /wheels
|
||||
|
||||
# Check if the with_database argument is set to 'true'
|
||||
RUN echo "Value of with_database is: ${with_database}"
|
||||
# If true, execute the following instructions
|
||||
RUN if [ "$with_database" = "true" ]; then \
|
||||
prisma generate; \
|
||||
chmod +x /app/retry_push.sh; \
|
||||
/app/retry_push.sh; \
|
||||
fi
|
||||
|
||||
EXPOSE 4000/tcp
|
||||
|
||||
|
|
|
@ -6,10 +6,10 @@
|
|||
LITELLM_MASTER_KEY="sk-1234"
|
||||
|
||||
############
|
||||
# Database - You can change these to any PostgreSQL database that has logical replication enabled.
|
||||
# Database - You can change these to any PostgreSQL database.
|
||||
############
|
||||
|
||||
# LITELLM_DATABASE_URL="your-postgres-db-url"
|
||||
DATABASE_URL="your-postgres-db-url"
|
||||
|
||||
|
||||
############
|
||||
|
|
|
@ -396,6 +396,47 @@ response = completion(
|
|||
)
|
||||
```
|
||||
|
||||
## OpenAI Proxy
|
||||
|
||||
Track spend across multiple projects/people
|
||||
|
||||
The proxy provides:
|
||||
1. [Hooks for auth](https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth)
|
||||
2. [Hooks for logging](https://docs.litellm.ai/docs/proxy/logging#step-1---create-your-custom-litellm-callback-class)
|
||||
3. [Cost tracking](https://docs.litellm.ai/docs/proxy/virtual_keys#tracking-spend)
|
||||
4. [Rate Limiting](https://docs.litellm.ai/docs/proxy/users#set-rate-limits)
|
||||
|
||||
### 📖 Proxy Endpoints - [Swagger Docs](https://litellm-api.up.railway.app/)
|
||||
|
||||
### Quick Start Proxy - CLI
|
||||
|
||||
```shell
|
||||
pip install litellm[proxy]
|
||||
```
|
||||
|
||||
#### Step 1: Start litellm proxy
|
||||
```shell
|
||||
$ litellm --model huggingface/bigcode/starcoder
|
||||
|
||||
#INFO: Proxy running on http://0.0.0.0:8000
|
||||
```
|
||||
|
||||
#### Step 2: Make ChatCompletions Request to Proxy
|
||||
```python
|
||||
import openai # openai v1.0.0+
|
||||
client = openai.OpenAI(api_key="anything",base_url="http://0.0.0.0:8000") # set proxy to base_url
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
])
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
|
||||
## More details
|
||||
* [exception mapping](./exception_mapping.md)
|
||||
* [retries + model fallbacks for completion()](./completion/reliable_completions.md)
|
||||
|
|
|
@ -161,7 +161,7 @@ litellm_settings:
|
|||
The proxy support 3 cache-controls:
|
||||
|
||||
- `ttl`: Will cache the response for the user-defined amount of time (in seconds).
|
||||
- `s-max-age`: Will only accept cached responses that are within user-defined range (in seconds).
|
||||
- `s-maxage`: Will only accept cached responses that are within user-defined range (in seconds).
|
||||
- `no-cache`: Will not return a cached response, but instead call the actual endpoint.
|
||||
|
||||
[Let us know if you need more](https://github.com/BerriAI/litellm/issues/1218)
|
||||
|
@ -237,7 +237,7 @@ chat_completion = client.chat.completions.create(
|
|||
],
|
||||
model="gpt-3.5-turbo",
|
||||
cache={
|
||||
"s-max-age": 600 # only get responses cached within last 10 minutes
|
||||
"s-maxage": 600 # only get responses cached within last 10 minutes
|
||||
}
|
||||
)
|
||||
```
|
||||
|
|
43
docs/my-website/docs/proxy/rules.md
Normal file
43
docs/my-website/docs/proxy/rules.md
Normal file
|
@ -0,0 +1,43 @@
|
|||
# Post-Call Rules
|
||||
|
||||
Use this to fail a request based on the output of an llm api call.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Step 1: Create a file (e.g. post_call_rules.py)
|
||||
|
||||
```python
|
||||
def my_custom_rule(input): # receives the model response
|
||||
if len(input) < 5: # trigger fallback if the model response is too short
|
||||
return False
|
||||
return True
|
||||
```
|
||||
|
||||
### Step 2. Point it to your proxy
|
||||
|
||||
```python
|
||||
litellm_settings:
|
||||
post_call_rules: post_call_rules.my_custom_rule
|
||||
num_retries: 3
|
||||
```
|
||||
|
||||
### Step 3. Start + test your proxy
|
||||
|
||||
```bash
|
||||
$ litellm /path/to/config.yaml
|
||||
```
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:8000/v1/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--data '{
|
||||
"model": "deepseek-coder",
|
||||
"messages": [{"role":"user","content":"What llm are you?"}],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 10,
|
||||
}'
|
||||
```
|
||||
---
|
||||
|
||||
This will now check if a response is > len 5, and if it fails, it'll retry a call 3 times before failing.
|
|
@ -112,6 +112,7 @@ const sidebars = {
|
|||
"proxy/reliability",
|
||||
"proxy/health",
|
||||
"proxy/call_hooks",
|
||||
"proxy/rules",
|
||||
"proxy/caching",
|
||||
"proxy/alerting",
|
||||
"proxy/logging",
|
||||
|
|
|
@ -375,6 +375,45 @@ response = completion(
|
|||
|
||||
Need a dedicated key? Email us @ krrish@berri.ai
|
||||
|
||||
## OpenAI Proxy
|
||||
|
||||
Track spend across multiple projects/people
|
||||
|
||||
The proxy provides:
|
||||
1. [Hooks for auth](https://docs.litellm.ai/docs/proxy/virtual_keys#custom-auth)
|
||||
2. [Hooks for logging](https://docs.litellm.ai/docs/proxy/logging#step-1---create-your-custom-litellm-callback-class)
|
||||
3. [Cost tracking](https://docs.litellm.ai/docs/proxy/virtual_keys#tracking-spend)
|
||||
4. [Rate Limiting](https://docs.litellm.ai/docs/proxy/users#set-rate-limits)
|
||||
|
||||
### 📖 Proxy Endpoints - [Swagger Docs](https://litellm-api.up.railway.app/)
|
||||
|
||||
### Quick Start Proxy - CLI
|
||||
|
||||
```shell
|
||||
pip install litellm[proxy]
|
||||
```
|
||||
|
||||
#### Step 1: Start litellm proxy
|
||||
```shell
|
||||
$ litellm --model huggingface/bigcode/starcoder
|
||||
|
||||
#INFO: Proxy running on http://0.0.0.0:8000
|
||||
```
|
||||
|
||||
#### Step 2: Make ChatCompletions Request to Proxy
|
||||
```python
|
||||
import openai # openai v1.0.0+
|
||||
client = openai.OpenAI(api_key="anything",base_url="http://0.0.0.0:8000") # set proxy to base_url
|
||||
# request sent to model set on litellm proxy, `litellm --model`
|
||||
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "this is a test request, write a short poem"
|
||||
}
|
||||
])
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
## More details
|
||||
* [exception mapping](./exception_mapping.md)
|
||||
|
|
|
@ -338,7 +338,8 @@ baseten_models: List = [
|
|||
] # FALCON 7B # WizardLM # Mosaic ML
|
||||
|
||||
|
||||
# used for token counting
|
||||
# used for Cost Tracking & Token counting
|
||||
# https://azure.microsoft.com/en-in/pricing/details/cognitive-services/openai-service/
|
||||
# Azure returns gpt-35-turbo in their responses, we need to map this to azure/gpt-3.5-turbo for token counting
|
||||
azure_llms = {
|
||||
"gpt-35-turbo": "azure/gpt-35-turbo",
|
||||
|
@ -346,6 +347,10 @@ azure_llms = {
|
|||
"gpt-35-turbo-instruct": "azure/gpt-35-turbo-instruct",
|
||||
}
|
||||
|
||||
azure_embedding_models = {
|
||||
"ada": "azure/ada",
|
||||
}
|
||||
|
||||
petals_models = [
|
||||
"petals-team/StableBeluga2",
|
||||
]
|
||||
|
|
|
@ -11,6 +11,7 @@ import litellm
|
|||
import time, logging
|
||||
import json, traceback, ast, hashlib
|
||||
from typing import Optional, Literal, List, Union, Any
|
||||
from openai._models import BaseModel as OpenAIObject
|
||||
|
||||
|
||||
def print_verbose(print_statement):
|
||||
|
@ -472,7 +473,10 @@ class Cache:
|
|||
else:
|
||||
cache_key = self.get_cache_key(*args, **kwargs)
|
||||
if cache_key is not None:
|
||||
max_age = kwargs.get("cache", {}).get("s-max-age", float("inf"))
|
||||
cache_control_args = kwargs.get("cache", {})
|
||||
max_age = cache_control_args.get(
|
||||
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
|
||||
)
|
||||
cached_result = self.cache.get_cache(cache_key)
|
||||
# Check if a timestamp was stored with the cached response
|
||||
if (
|
||||
|
@ -529,7 +533,7 @@ class Cache:
|
|||
else:
|
||||
cache_key = self.get_cache_key(*args, **kwargs)
|
||||
if cache_key is not None:
|
||||
if isinstance(result, litellm.ModelResponse):
|
||||
if isinstance(result, OpenAIObject):
|
||||
result = result.model_dump_json()
|
||||
|
||||
## Get Cache-Controls ##
|
||||
|
|
|
@ -724,6 +724,21 @@ class AzureChatCompletion(BaseLLM):
|
|||
client_session = litellm.aclient_session or httpx.AsyncClient(
|
||||
transport=AsyncCustomHTTPTransport(), # handle dall-e-2 calls
|
||||
)
|
||||
if "gateway.ai.cloudflare.com" in api_base:
|
||||
## build base url - assume api base includes resource name
|
||||
if not api_base.endswith("/"):
|
||||
api_base += "/"
|
||||
api_base += f"{model}"
|
||||
client = AsyncAzureOpenAI(
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
http_client=client_session,
|
||||
)
|
||||
model = None
|
||||
# cloudflare ai gateway, needs model=None
|
||||
else:
|
||||
client = AsyncAzureOpenAI(
|
||||
api_version=api_version,
|
||||
azure_endpoint=api_base,
|
||||
|
@ -732,6 +747,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
http_client=client_session,
|
||||
)
|
||||
|
||||
# only run this check if it's not cloudflare ai gateway
|
||||
if model is None and mode != "image_generation":
|
||||
raise Exception("model is not set")
|
||||
|
||||
|
|
0
litellm/llms/custom_httpx/bedrock_async.py
Normal file
0
litellm/llms/custom_httpx/bedrock_async.py
Normal file
|
@ -14,12 +14,18 @@ model_list:
|
|||
- model_name: BEDROCK_GROUP
|
||||
litellm_params:
|
||||
model: bedrock/cohere.command-text-v14
|
||||
- model_name: Azure OpenAI GPT-4 Canada-East (External)
|
||||
- model_name: openai-gpt-3.5
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
model_info:
|
||||
mode: chat
|
||||
- model_name: azure-cloudflare
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: https://gateway.ai.cloudflare.com/v1/0399b10e77ac6668c80404a5ff49eb37/litellm-test/azure-openai/openai-gpt-4-test-v-1
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: "2023-07-01-preview"
|
||||
- model_name: azure-embedding-model
|
||||
litellm_params:
|
||||
model: azure/azure-embedding-model
|
||||
|
|
|
@ -307,9 +307,8 @@ async def user_api_key_auth(
|
|||
)
|
||||
|
||||
|
||||
def prisma_setup(database_url: Optional[str]):
|
||||
async def prisma_setup(database_url: Optional[str]):
|
||||
global prisma_client, proxy_logging_obj, user_api_key_cache
|
||||
|
||||
if (
|
||||
database_url is not None and prisma_client is None
|
||||
): # don't re-initialize prisma client after initial init
|
||||
|
@ -321,6 +320,8 @@ def prisma_setup(database_url: Optional[str]):
|
|||
print_verbose(
|
||||
f"Error when initializing prisma, Ensure you run pip install prisma {str(e)}"
|
||||
)
|
||||
if prisma_client is not None and prisma_client.db.is_connected() == False:
|
||||
await prisma_client.connect()
|
||||
|
||||
|
||||
def load_from_azure_key_vault(use_azure_key_vault: bool = False):
|
||||
|
@ -502,21 +503,110 @@ async def _run_background_health_check():
|
|||
await asyncio.sleep(health_check_interval)
|
||||
|
||||
|
||||
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
|
||||
config = {}
|
||||
try:
|
||||
if os.path.exists(config_file_path):
|
||||
class ProxyConfig:
|
||||
"""
|
||||
Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
async def get_config(self, config_file_path: Optional[str] = None) -> dict:
|
||||
global prisma_client, user_config_file_path
|
||||
|
||||
file_path = config_file_path or user_config_file_path
|
||||
if config_file_path is not None:
|
||||
user_config_file_path = config_file_path
|
||||
with open(config_file_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
# Load existing config
|
||||
## Yaml
|
||||
if file_path is not None:
|
||||
if os.path.exists(f"{file_path}"):
|
||||
with open(f"{file_path}", "r") as config_file:
|
||||
config = yaml.safe_load(config_file)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Path to config does not exist, Current working directory: {os.getcwd()}, 'os.path.exists({config_file_path})' returned False"
|
||||
raise Exception(f"File not found! - {file_path}")
|
||||
|
||||
## DB
|
||||
if (
|
||||
prisma_client is not None
|
||||
and litellm.get_secret("SAVE_CONFIG_TO_DB", False) == True
|
||||
):
|
||||
await prisma_setup(database_url=None) # in case it's not been connected yet
|
||||
_tasks = []
|
||||
keys = [
|
||||
"model_list",
|
||||
"general_settings",
|
||||
"router_settings",
|
||||
"litellm_settings",
|
||||
]
|
||||
for k in keys:
|
||||
response = prisma_client.get_generic_data(
|
||||
key="param_name", value=k, table_name="config"
|
||||
)
|
||||
_tasks.append(response)
|
||||
|
||||
responses = await asyncio.gather(*_tasks)
|
||||
|
||||
return config
|
||||
|
||||
async def save_config(self, new_config: dict):
|
||||
global prisma_client, llm_router, user_config_file_path, llm_model_list, general_settings
|
||||
# Load existing config
|
||||
backup_config = await self.get_config()
|
||||
|
||||
# Save the updated config
|
||||
## YAML
|
||||
with open(f"{user_config_file_path}", "w") as config_file:
|
||||
yaml.dump(new_config, config_file, default_flow_style=False)
|
||||
|
||||
# update Router - verifies if this is a valid config
|
||||
try:
|
||||
(
|
||||
llm_router,
|
||||
llm_model_list,
|
||||
general_settings,
|
||||
) = await proxy_config.load_config(
|
||||
router=llm_router, config_file_path=user_config_file_path
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(f"Exception while reading Config: {e}")
|
||||
traceback.print_exc()
|
||||
# Revert to old config instead
|
||||
with open(f"{user_config_file_path}", "w") as config_file:
|
||||
yaml.dump(backup_config, config_file, default_flow_style=False)
|
||||
raise HTTPException(status_code=400, detail="Invalid config passed in")
|
||||
|
||||
## DB - writes valid config to db
|
||||
"""
|
||||
- Do not write restricted params like 'api_key' to the database
|
||||
- if api_key is passed, save that to the local environment or connected secret manage (maybe expose `litellm.save_secret()`)
|
||||
"""
|
||||
if (
|
||||
prisma_client is not None
|
||||
and litellm.get_secret("SAVE_CONFIG_TO_DB", default_value=False) == True
|
||||
):
|
||||
### KEY REMOVAL ###
|
||||
models = new_config.get("model_list", [])
|
||||
for m in models:
|
||||
if m.get("litellm_params", {}).get("api_key", None) is not None:
|
||||
# pop the key
|
||||
api_key = m["litellm_params"].pop("api_key")
|
||||
# store in local env
|
||||
key_name = f"LITELLM_MODEL_KEY_{uuid.uuid4()}"
|
||||
os.environ[key_name] = api_key
|
||||
# save the key name (not the value)
|
||||
m["litellm_params"]["api_key"] = f"os.environ/{key_name}"
|
||||
await prisma_client.insert_data(data=new_config, table_name="config")
|
||||
|
||||
async def load_config(
|
||||
self, router: Optional[litellm.Router], config_file_path: str
|
||||
):
|
||||
"""
|
||||
Load config values into proxy global state
|
||||
"""
|
||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
|
||||
|
||||
# Load existing config
|
||||
config = await self.get_config(config_file_path=config_file_path)
|
||||
## PRINT YAML FOR CONFIRMING IT WORKS
|
||||
printed_yaml = copy.deepcopy(config)
|
||||
printed_yaml.pop("environment_variables", None)
|
||||
|
@ -559,12 +649,14 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
cache_port = litellm.get_secret("REDIS_PORT", None)
|
||||
cache_password = litellm.get_secret("REDIS_PASSWORD", None)
|
||||
|
||||
cache_params = {
|
||||
cache_params.update(
|
||||
{
|
||||
"type": cache_type,
|
||||
"host": cache_host,
|
||||
"port": cache_port,
|
||||
"password": cache_password,
|
||||
}
|
||||
)
|
||||
# Assuming cache_type, cache_host, cache_port, and cache_password are strings
|
||||
print( # noqa
|
||||
f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}"
|
||||
|
@ -604,7 +696,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
for callback in value:
|
||||
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
|
||||
if "." in callback:
|
||||
litellm.success_callback.append(get_instance_fn(value=callback))
|
||||
litellm.success_callback.append(
|
||||
get_instance_fn(value=callback)
|
||||
)
|
||||
# these are litellm callbacks - "langfuse", "sentry", "wandb"
|
||||
else:
|
||||
litellm.success_callback.append(callback)
|
||||
|
@ -618,7 +712,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
for callback in value:
|
||||
# user passed custom_callbacks.async_on_succes_logger. They need us to import a function
|
||||
if "." in callback:
|
||||
litellm.failure_callback.append(get_instance_fn(value=callback))
|
||||
litellm.failure_callback.append(
|
||||
get_instance_fn(value=callback)
|
||||
)
|
||||
# these are litellm callbacks - "langfuse", "sentry", "wandb"
|
||||
else:
|
||||
litellm.failure_callback.append(callback)
|
||||
|
@ -665,7 +761,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
print_verbose(f"GOING INTO LITELLM.GET_SECRET!")
|
||||
database_url = litellm.get_secret(database_url)
|
||||
print_verbose(f"RETRIEVED DB URL: {database_url}")
|
||||
prisma_setup(database_url=database_url)
|
||||
await prisma_setup(database_url=database_url)
|
||||
## COST TRACKING ##
|
||||
cost_tracking()
|
||||
### MASTER KEY ###
|
||||
|
@ -730,6 +826,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
return router, model_list, general_settings
|
||||
|
||||
|
||||
proxy_config = ProxyConfig()
|
||||
|
||||
|
||||
async def generate_key_helper_fn(
|
||||
duration: Optional[str],
|
||||
models: list,
|
||||
|
@ -797,6 +896,7 @@ async def generate_key_helper_fn(
|
|||
"max_budget": max_budget,
|
||||
"user_email": user_email,
|
||||
}
|
||||
print_verbose("PrismaClient: Before Insert Data")
|
||||
new_verification_token = await prisma_client.insert_data(
|
||||
data=verification_token_data
|
||||
)
|
||||
|
@ -831,7 +931,7 @@ def save_worker_config(**data):
|
|||
os.environ["WORKER_CONFIG"] = json.dumps(data)
|
||||
|
||||
|
||||
def initialize(
|
||||
async def initialize(
|
||||
model=None,
|
||||
alias=None,
|
||||
api_base=None,
|
||||
|
@ -849,7 +949,7 @@ def initialize(
|
|||
use_queue=False,
|
||||
config=None,
|
||||
):
|
||||
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth
|
||||
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth, prisma_client
|
||||
generate_feedback_box()
|
||||
user_model = model
|
||||
user_debug = debug
|
||||
|
@ -857,9 +957,11 @@ def initialize(
|
|||
litellm.set_verbose = True
|
||||
dynamic_config = {"general": {}, user_model: {}}
|
||||
if config:
|
||||
llm_router, llm_model_list, general_settings = load_router_config(
|
||||
router=llm_router, config_file_path=config
|
||||
)
|
||||
(
|
||||
llm_router,
|
||||
llm_model_list,
|
||||
general_settings,
|
||||
) = await proxy_config.load_config(router=llm_router, config_file_path=config)
|
||||
if headers: # model-specific param
|
||||
user_headers = headers
|
||||
dynamic_config[user_model]["headers"] = headers
|
||||
|
@ -988,7 +1090,7 @@ def parse_cache_control(cache_control):
|
|||
|
||||
@router.on_event("startup")
|
||||
async def startup_event():
|
||||
global prisma_client, master_key, use_background_health_checks
|
||||
global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings
|
||||
import json
|
||||
|
||||
### LOAD MASTER KEY ###
|
||||
|
@ -1000,12 +1102,11 @@ async def startup_event():
|
|||
print_verbose(f"worker_config: {worker_config}")
|
||||
# check if it's a valid file path
|
||||
if os.path.isfile(worker_config):
|
||||
initialize(config=worker_config)
|
||||
await initialize(**worker_config)
|
||||
else:
|
||||
# if not, assume it's a json string
|
||||
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
||||
initialize(**worker_config)
|
||||
|
||||
await initialize(**worker_config)
|
||||
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
|
||||
|
||||
if use_background_health_checks:
|
||||
|
@ -1013,10 +1114,6 @@ async def startup_event():
|
|||
_run_background_health_check()
|
||||
) # start the background health check coroutine.
|
||||
|
||||
print_verbose(f"prisma client - {prisma_client}")
|
||||
if prisma_client is not None:
|
||||
await prisma_client.connect()
|
||||
|
||||
if prisma_client is not None and master_key is not None:
|
||||
# add master key to db
|
||||
await generate_key_helper_fn(
|
||||
|
@ -1220,7 +1317,7 @@ async def chat_completion(
|
|||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
background_tasks: BackgroundTasks = BackgroundTasks(),
|
||||
):
|
||||
global general_settings, user_debug, proxy_logging_obj
|
||||
global general_settings, user_debug, proxy_logging_obj, llm_model_list
|
||||
try:
|
||||
data = {}
|
||||
body = await request.body()
|
||||
|
@ -1673,6 +1770,7 @@ async def generate_key_fn(
|
|||
- expires: (datetime) Datetime object for when key expires.
|
||||
- user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
|
||||
"""
|
||||
print_verbose("entered /key/generate")
|
||||
data_json = data.json() # type: ignore
|
||||
response = await generate_key_helper_fn(**data_json)
|
||||
return GenerateKeyResponse(
|
||||
|
@ -1825,7 +1923,7 @@ async def user_auth(request: Request):
|
|||
|
||||
### Check if user email in user table
|
||||
response = await prisma_client.get_generic_data(
|
||||
key="user_email", value=user_email, db="users"
|
||||
key="user_email", value=user_email, table_name="users"
|
||||
)
|
||||
### if so - generate a 24 hr key with that user id
|
||||
if response is not None:
|
||||
|
@ -1883,16 +1981,13 @@ async def user_update(request: Request):
|
|||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def add_new_model(model_params: ModelParams):
|
||||
global llm_router, llm_model_list, general_settings, user_config_file_path
|
||||
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
|
||||
try:
|
||||
print_verbose(f"User config path: {user_config_file_path}")
|
||||
# Load existing config
|
||||
if os.path.exists(f"{user_config_file_path}"):
|
||||
with open(f"{user_config_file_path}", "r") as config_file:
|
||||
config = yaml.safe_load(config_file)
|
||||
else:
|
||||
config = {"model_list": []}
|
||||
backup_config = copy.deepcopy(config)
|
||||
config = await proxy_config.get_config()
|
||||
|
||||
print_verbose(f"User config path: {user_config_file_path}")
|
||||
|
||||
print_verbose(f"Loaded config: {config}")
|
||||
# Add the new model to the config
|
||||
model_info = model_params.model_info.json()
|
||||
|
@ -1907,22 +2002,8 @@ async def add_new_model(model_params: ModelParams):
|
|||
|
||||
print_verbose(f"updated model list: {config['model_list']}")
|
||||
|
||||
# Save the updated config
|
||||
with open(f"{user_config_file_path}", "w") as config_file:
|
||||
yaml.dump(config, config_file, default_flow_style=False)
|
||||
|
||||
# update Router
|
||||
try:
|
||||
llm_router, llm_model_list, general_settings = load_router_config(
|
||||
router=llm_router, config_file_path=user_config_file_path
|
||||
)
|
||||
except Exception as e:
|
||||
# Rever to old config instead
|
||||
with open(f"{user_config_file_path}", "w") as config_file:
|
||||
yaml.dump(backup_config, config_file, default_flow_style=False)
|
||||
raise HTTPException(status_code=400, detail="Invalid Model passed in")
|
||||
|
||||
print_verbose(f"llm_model_list: {llm_model_list}")
|
||||
# Save new config
|
||||
await proxy_config.save_config(new_config=config)
|
||||
return {"message": "Model added successfully"}
|
||||
|
||||
except Exception as e:
|
||||
|
@ -1949,13 +2030,10 @@ async def add_new_model(model_params: ModelParams):
|
|||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def model_info_v1(request: Request):
|
||||
global llm_model_list, general_settings, user_config_file_path
|
||||
global llm_model_list, general_settings, user_config_file_path, proxy_config
|
||||
|
||||
# Load existing config
|
||||
if os.path.exists(f"{user_config_file_path}"):
|
||||
with open(f"{user_config_file_path}", "r") as config_file:
|
||||
config = yaml.safe_load(config_file)
|
||||
else:
|
||||
config = {"model_list": []} # handle base case
|
||||
config = await proxy_config.get_config()
|
||||
|
||||
all_models = config["model_list"]
|
||||
for model in all_models:
|
||||
|
@ -1984,18 +2062,18 @@ async def model_info_v1(request: Request):
|
|||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def delete_model(model_info: ModelInfoDelete):
|
||||
global llm_router, llm_model_list, general_settings, user_config_file_path
|
||||
global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config
|
||||
try:
|
||||
if not os.path.exists(user_config_file_path):
|
||||
raise HTTPException(status_code=404, detail="Config file does not exist.")
|
||||
|
||||
with open(user_config_file_path, "r") as config_file:
|
||||
config = yaml.safe_load(config_file)
|
||||
# Load existing config
|
||||
config = await proxy_config.get_config()
|
||||
|
||||
# If model_list is not in the config, nothing can be deleted
|
||||
if "model_list" not in config:
|
||||
if len(config.get("model_list", [])) == 0:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="No model list available in the config."
|
||||
status_code=400, detail="No model list available in the config."
|
||||
)
|
||||
|
||||
# Check if the model with the specified model_id exists
|
||||
|
@ -2008,19 +2086,14 @@ async def delete_model(model_info: ModelInfoDelete):
|
|||
# If the model was not found, return an error
|
||||
if model_to_delete is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Model with given model_id not found."
|
||||
status_code=400, detail="Model with given model_id not found."
|
||||
)
|
||||
|
||||
# Remove model from the list and save the updated config
|
||||
config["model_list"].remove(model_to_delete)
|
||||
with open(user_config_file_path, "w") as config_file:
|
||||
yaml.dump(config, config_file, default_flow_style=False)
|
||||
|
||||
# Update Router
|
||||
llm_router, llm_model_list, general_settings = load_router_config(
|
||||
router=llm_router, config_file_path=user_config_file_path
|
||||
)
|
||||
|
||||
# Save updated config
|
||||
config = await proxy_config.save_config(new_config=config)
|
||||
return {"message": "Model deleted successfully"}
|
||||
|
||||
except HTTPException as e:
|
||||
|
@ -2200,14 +2273,11 @@ async def update_config(config_info: ConfigYAML):
|
|||
|
||||
Currently supports modifying General Settings + LiteLLM settings
|
||||
"""
|
||||
global llm_router, llm_model_list, general_settings
|
||||
global llm_router, llm_model_list, general_settings, proxy_config, proxy_logging_obj
|
||||
try:
|
||||
# Load existing config
|
||||
if os.path.exists(f"{user_config_file_path}"):
|
||||
with open(f"{user_config_file_path}", "r") as config_file:
|
||||
config = yaml.safe_load(config_file)
|
||||
else:
|
||||
config = {}
|
||||
config = await proxy_config.get_config()
|
||||
|
||||
backup_config = copy.deepcopy(config)
|
||||
print_verbose(f"Loaded config: {config}")
|
||||
|
||||
|
@ -2240,20 +2310,13 @@ async def update_config(config_info: ConfigYAML):
|
|||
}
|
||||
|
||||
# Save the updated config
|
||||
with open(f"{user_config_file_path}", "w") as config_file:
|
||||
yaml.dump(config, config_file, default_flow_style=False)
|
||||
await proxy_config.save_config(new_config=config)
|
||||
|
||||
# update Router
|
||||
try:
|
||||
llm_router, llm_model_list, general_settings = load_router_config(
|
||||
router=llm_router, config_file_path=user_config_file_path
|
||||
)
|
||||
except Exception as e:
|
||||
# Rever to old config instead
|
||||
with open(f"{user_config_file_path}", "w") as config_file:
|
||||
yaml.dump(backup_config, config_file, default_flow_style=False)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid config passed in. Errror - {str(e)}"
|
||||
# Test new connections
|
||||
## Slack
|
||||
if "slack" in config.get("general_settings", {}).get("alerting", []):
|
||||
await proxy_logging_obj.alerting_handler(
|
||||
message="This is a test", level="Low"
|
||||
)
|
||||
return {"message": "Config updated successfully"}
|
||||
except HTTPException as e:
|
||||
|
@ -2263,6 +2326,21 @@ async def update_config(config_info: ConfigYAML):
|
|||
raise HTTPException(status_code=500, detail=f"An error occurred - {str(e)}")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/config/get",
|
||||
tags=["config.yaml"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def get_config():
|
||||
"""
|
||||
Master key only.
|
||||
|
||||
Returns the config. Mainly used for testing.
|
||||
"""
|
||||
global proxy_config
|
||||
return await proxy_config.get_config()
|
||||
|
||||
|
||||
@router.get("/config/yaml", tags=["config.yaml"])
|
||||
async def config_yaml_endpoint(config_info: ConfigYAML):
|
||||
"""
|
||||
|
@ -2351,6 +2429,28 @@ async def health_endpoint(
|
|||
}
|
||||
|
||||
|
||||
@router.get("/health/readiness", tags=["health"])
|
||||
async def health_readiness():
|
||||
"""
|
||||
Unprotected endpoint for checking if worker can receive requests
|
||||
"""
|
||||
global prisma_client
|
||||
if prisma_client is not None: # if db passed in, check if it's connected
|
||||
if prisma_client.db.is_connected() == True:
|
||||
return {"status": "healthy"}
|
||||
else:
|
||||
return {"status": "healthy"}
|
||||
raise HTTPException(status_code=503, detail="Service Unhealthy")
|
||||
|
||||
|
||||
@router.get("/health/liveliness", tags=["health"])
|
||||
async def health_liveliness():
|
||||
"""
|
||||
Unprotected endpoint for checking if worker is alive
|
||||
"""
|
||||
return "I'm alive!"
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def home(request: Request):
|
||||
return "LiteLLM: RUNNING"
|
||||
|
|
|
@ -26,3 +26,8 @@ model LiteLLM_VerificationToken {
|
|||
max_parallel_requests Int?
|
||||
metadata Json @default("{}")
|
||||
}
|
||||
|
||||
model LiteLLM_Config {
|
||||
param_name String @id
|
||||
param_value Json?
|
||||
}
|
|
@ -250,12 +250,18 @@ def on_backoff(details):
|
|||
|
||||
class PrismaClient:
|
||||
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
|
||||
### Check if prisma client can be imported (setup done in Docker build)
|
||||
try:
|
||||
from prisma import Client # type: ignore
|
||||
|
||||
os.environ["DATABASE_URL"] = database_url
|
||||
self.db = Client() # Client to connect to Prisma db
|
||||
except: # if not - go through normal setup process
|
||||
print_verbose(
|
||||
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
|
||||
)
|
||||
## init logging object
|
||||
self.proxy_logging_obj = proxy_logging_obj
|
||||
self.connected = False
|
||||
os.environ["DATABASE_URL"] = database_url
|
||||
# Save the current working directory
|
||||
original_dir = os.getcwd()
|
||||
|
@ -301,20 +307,24 @@ class PrismaClient:
|
|||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
db: Literal["users", "keys"],
|
||||
table_name: Literal["users", "keys", "config"],
|
||||
):
|
||||
"""
|
||||
Generic implementation of get data
|
||||
"""
|
||||
try:
|
||||
if db == "users":
|
||||
if table_name == "users":
|
||||
response = await self.db.litellm_usertable.find_first(
|
||||
where={key: value} # type: ignore
|
||||
)
|
||||
elif db == "keys":
|
||||
elif table_name == "keys":
|
||||
response = await self.db.litellm_verificationtoken.find_first( # type: ignore
|
||||
where={key: value} # type: ignore
|
||||
)
|
||||
elif table_name == "config":
|
||||
response = await self.db.litellm_config.find_first( # type: ignore
|
||||
where={key: value} # type: ignore
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
asyncio.create_task(
|
||||
|
@ -336,15 +346,19 @@ class PrismaClient:
|
|||
user_id: Optional[str] = None,
|
||||
):
|
||||
try:
|
||||
print_verbose("PrismaClient: get_data")
|
||||
|
||||
response = None
|
||||
if token is not None:
|
||||
# check if plain text or hash
|
||||
hashed_token = token
|
||||
if token.startswith("sk-"):
|
||||
hashed_token = self.hash_token(token=token)
|
||||
print_verbose("PrismaClient: find_unique")
|
||||
response = await self.db.litellm_verificationtoken.find_unique(
|
||||
where={"token": hashed_token}
|
||||
)
|
||||
print_verbose(f"PrismaClient: response={response}")
|
||||
if response:
|
||||
# Token exists, now check expiration.
|
||||
if response.expires is not None and expires is not None:
|
||||
|
@ -372,6 +386,10 @@ class PrismaClient:
|
|||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
print_verbose(f"LiteLLM Prisma Client Exception: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||
)
|
||||
|
@ -385,17 +403,23 @@ class PrismaClient:
|
|||
max_time=10, # maximum total time to retry for
|
||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||
)
|
||||
async def insert_data(self, data: dict):
|
||||
async def insert_data(
|
||||
self, data: dict, table_name: Literal["user+key", "config"] = "user+key"
|
||||
):
|
||||
"""
|
||||
Add a key to the database. If it already exists, do nothing.
|
||||
"""
|
||||
try:
|
||||
if table_name == "user+key":
|
||||
token = data["token"]
|
||||
hashed_token = self.hash_token(token=token)
|
||||
db_data = self.jsonify_object(data=data)
|
||||
db_data["token"] = hashed_token
|
||||
max_budget = db_data.pop("max_budget", None)
|
||||
user_email = db_data.pop("user_email", None)
|
||||
print_verbose(
|
||||
"PrismaClient: Before upsert into litellm_verificationtoken"
|
||||
)
|
||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||
where={
|
||||
"token": hashed_token,
|
||||
|
@ -418,7 +442,32 @@ class PrismaClient:
|
|||
},
|
||||
)
|
||||
return new_verification_token
|
||||
elif table_name == "config":
|
||||
"""
|
||||
For each param,
|
||||
get the existing table values
|
||||
|
||||
Add the new values
|
||||
|
||||
Update DB
|
||||
"""
|
||||
tasks = []
|
||||
for k, v in data.items():
|
||||
updated_data = v
|
||||
updated_data = json.dumps(updated_data)
|
||||
updated_table_row = self.db.litellm_config.upsert(
|
||||
where={"param_name": k},
|
||||
data={
|
||||
"create": {"param_name": k, "param_value": updated_data},
|
||||
"update": {"param_value": updated_data},
|
||||
},
|
||||
)
|
||||
|
||||
tasks.append(updated_table_row)
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
except Exception as e:
|
||||
print_verbose(f"LiteLLM Prisma Client Exception: {e}")
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||
)
|
||||
|
@ -505,11 +554,7 @@ class PrismaClient:
|
|||
)
|
||||
async def connect(self):
|
||||
try:
|
||||
if self.connected == False:
|
||||
await self.db.connect()
|
||||
self.connected = True
|
||||
else:
|
||||
return
|
||||
except Exception as e:
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||
|
|
|
@ -773,6 +773,10 @@ class Router:
|
|||
)
|
||||
original_exception = e
|
||||
try:
|
||||
if (
|
||||
hasattr(e, "status_code") and e.status_code == 400
|
||||
): # don't retry a malformed request
|
||||
raise e
|
||||
self.print_verbose(f"Trying to fallback b/w models")
|
||||
if (
|
||||
isinstance(e, litellm.ContextWindowExceededError)
|
||||
|
@ -846,7 +850,7 @@ class Router:
|
|||
return response
|
||||
except Exception as e:
|
||||
original_exception = e
|
||||
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available
|
||||
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
|
||||
if (
|
||||
isinstance(original_exception, litellm.ContextWindowExceededError)
|
||||
and context_window_fallbacks is None
|
||||
|
@ -864,12 +868,12 @@ class Router:
|
|||
min_timeout=self.retry_after,
|
||||
)
|
||||
await asyncio.sleep(timeout)
|
||||
elif (
|
||||
hasattr(original_exception, "status_code")
|
||||
and hasattr(original_exception, "response")
|
||||
and litellm._should_retry(status_code=original_exception.status_code)
|
||||
elif hasattr(original_exception, "status_code") and litellm._should_retry(
|
||||
status_code=original_exception.status_code
|
||||
):
|
||||
if hasattr(original_exception, "response") and hasattr(
|
||||
original_exception.response, "headers"
|
||||
):
|
||||
if hasattr(original_exception.response, "headers"):
|
||||
timeout = litellm._calculate_retry_after(
|
||||
remaining_retries=num_retries,
|
||||
max_retries=num_retries,
|
||||
|
@ -1326,6 +1330,7 @@ class Router:
|
|||
local_only=True,
|
||||
) # cache for 1 hr
|
||||
|
||||
cache_key = f"{model_id}_client"
|
||||
_client = openai.AzureOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
|
|
|
@ -138,14 +138,15 @@ def test_async_completion_cloudflare():
|
|||
response = await litellm.acompletion(
|
||||
model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
|
||||
messages=[{"content": "what llm are you", "role": "user"}],
|
||||
max_tokens=50,
|
||||
max_tokens=5,
|
||||
num_retries=3,
|
||||
)
|
||||
print(response)
|
||||
return response
|
||||
|
||||
response = asyncio.run(test())
|
||||
text_response = response["choices"][0]["message"]["content"]
|
||||
assert len(text_response) > 5 # more than 5 chars in response
|
||||
assert len(text_response) > 1 # more than 1 chars in response
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
@ -166,7 +167,7 @@ def test_get_cloudflare_response_streaming():
|
|||
model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
|
||||
messages=messages,
|
||||
stream=True,
|
||||
timeout=5,
|
||||
num_retries=3, # cloudflare ai workers is EXTREMELY UNSTABLE
|
||||
)
|
||||
print(type(response))
|
||||
|
||||
|
|
|
@ -91,7 +91,7 @@ def test_caching_with_cache_controls():
|
|||
model="gpt-3.5-turbo", messages=messages, cache={"ttl": 0}
|
||||
)
|
||||
response2 = completion(
|
||||
model="gpt-3.5-turbo", messages=messages, cache={"s-max-age": 10}
|
||||
model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 10}
|
||||
)
|
||||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
|
@ -105,7 +105,7 @@ def test_caching_with_cache_controls():
|
|||
model="gpt-3.5-turbo", messages=messages, cache={"ttl": 5}
|
||||
)
|
||||
response2 = completion(
|
||||
model="gpt-3.5-turbo", messages=messages, cache={"s-max-age": 5}
|
||||
model="gpt-3.5-turbo", messages=messages, cache={"s-maxage": 5}
|
||||
)
|
||||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
|
@ -167,6 +167,8 @@ small text
|
|||
def test_embedding_caching():
|
||||
import time
|
||||
|
||||
# litellm.set_verbose = True
|
||||
|
||||
litellm.cache = Cache()
|
||||
text_to_embed = [embedding_large_text]
|
||||
start_time = time.time()
|
||||
|
@ -182,7 +184,7 @@ def test_embedding_caching():
|
|||
model="text-embedding-ada-002", input=text_to_embed, caching=True
|
||||
)
|
||||
end_time = time.time()
|
||||
print(f"embedding2: {embedding2}")
|
||||
# print(f"embedding2: {embedding2}")
|
||||
print(f"Embedding 2 response time: {end_time - start_time} seconds")
|
||||
|
||||
litellm.cache = None
|
||||
|
@ -274,7 +276,7 @@ def test_redis_cache_completion():
|
|||
port=os.environ["REDIS_PORT"],
|
||||
password=os.environ["REDIS_PASSWORD"],
|
||||
)
|
||||
print("test2 for caching")
|
||||
print("test2 for Redis Caching - non streaming")
|
||||
response1 = completion(
|
||||
model="gpt-3.5-turbo", messages=messages, caching=True, max_tokens=20
|
||||
)
|
||||
|
@ -326,6 +328,10 @@ def test_redis_cache_completion():
|
|||
print(f"response4: {response4}")
|
||||
pytest.fail(f"Error occurred:")
|
||||
|
||||
assert response1.id == response2.id
|
||||
assert response1.created == response2.created
|
||||
assert response1.choices[0].message.content == response2.choices[0].message.content
|
||||
|
||||
|
||||
# test_redis_cache_completion()
|
||||
|
||||
|
@ -395,7 +401,7 @@ def test_redis_cache_completion_stream():
|
|||
"""
|
||||
|
||||
|
||||
# test_redis_cache_completion_stream()
|
||||
test_redis_cache_completion_stream()
|
||||
|
||||
|
||||
def test_redis_cache_acompletion_stream():
|
||||
|
@ -529,6 +535,7 @@ def test_redis_cache_acompletion_stream_bedrock():
|
|||
assert (
|
||||
response_1_content == response_2_content
|
||||
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
|
@ -537,7 +544,7 @@ def test_redis_cache_acompletion_stream_bedrock():
|
|||
raise e
|
||||
|
||||
|
||||
def test_s3_cache_acompletion_stream_bedrock():
|
||||
def test_s3_cache_acompletion_stream_azure():
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
|
@ -556,10 +563,13 @@ def test_s3_cache_acompletion_stream_bedrock():
|
|||
response_1_content = ""
|
||||
response_2_content = ""
|
||||
|
||||
response_1_created = ""
|
||||
response_2_created = ""
|
||||
|
||||
async def call1():
|
||||
nonlocal response_1_content
|
||||
nonlocal response_1_content, response_1_created
|
||||
response1 = await litellm.acompletion(
|
||||
model="bedrock/anthropic.claude-v1",
|
||||
model="azure/chatgpt-v-2",
|
||||
messages=messages,
|
||||
max_tokens=40,
|
||||
temperature=1,
|
||||
|
@ -567,6 +577,7 @@ def test_s3_cache_acompletion_stream_bedrock():
|
|||
)
|
||||
async for chunk in response1:
|
||||
print(chunk)
|
||||
response_1_created = chunk.created
|
||||
response_1_content += chunk.choices[0].delta.content or ""
|
||||
print(response_1_content)
|
||||
|
||||
|
@ -575,9 +586,9 @@ def test_s3_cache_acompletion_stream_bedrock():
|
|||
print("\n\n Response 1 content: ", response_1_content, "\n\n")
|
||||
|
||||
async def call2():
|
||||
nonlocal response_2_content
|
||||
nonlocal response_2_content, response_2_created
|
||||
response2 = await litellm.acompletion(
|
||||
model="bedrock/anthropic.claude-v1",
|
||||
model="azure/chatgpt-v-2",
|
||||
messages=messages,
|
||||
max_tokens=40,
|
||||
temperature=1,
|
||||
|
@ -586,14 +597,23 @@ def test_s3_cache_acompletion_stream_bedrock():
|
|||
async for chunk in response2:
|
||||
print(chunk)
|
||||
response_2_content += chunk.choices[0].delta.content or ""
|
||||
response_2_created = chunk.created
|
||||
print(response_2_content)
|
||||
|
||||
asyncio.run(call2())
|
||||
print("\nresponse 1", response_1_content)
|
||||
print("\nresponse 2", response_2_content)
|
||||
|
||||
assert (
|
||||
response_1_content == response_2_content
|
||||
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||
|
||||
# prioritizing getting a new deploy out - will look at this in the next deploy
|
||||
# print("response 1 created", response_1_created)
|
||||
# print("response 2 created", response_2_created)
|
||||
|
||||
# assert response_1_created == response_2_created
|
||||
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
|
@ -602,7 +622,7 @@ def test_s3_cache_acompletion_stream_bedrock():
|
|||
raise e
|
||||
|
||||
|
||||
test_s3_cache_acompletion_stream_bedrock()
|
||||
# test_s3_cache_acompletion_stream_azure()
|
||||
|
||||
|
||||
# test_redis_cache_acompletion_stream_bedrock()
|
||||
|
|
|
@ -749,10 +749,14 @@ def test_completion_ollama_hosted():
|
|||
model="ollama/phi",
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
num_retries=3,
|
||||
timeout=90,
|
||||
api_base="https://test-ollama-endpoint.onrender.com",
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
except Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
@ -1626,6 +1630,7 @@ def test_completion_anyscale_api():
|
|||
|
||||
|
||||
def test_azure_cloudflare_api():
|
||||
litellm.set_verbose = True
|
||||
try:
|
||||
messages = [
|
||||
{
|
||||
|
@ -1641,11 +1646,12 @@ def test_azure_cloudflare_api():
|
|||
)
|
||||
print(f"response: {response}")
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
traceback.print_exc()
|
||||
pass
|
||||
|
||||
|
||||
# test_azure_cloudflare_api()
|
||||
test_azure_cloudflare_api()
|
||||
|
||||
|
||||
def test_completion_anyscale_2():
|
||||
|
@ -1931,6 +1937,7 @@ def test_completion_cloudflare():
|
|||
model="cloudflare/@cf/meta/llama-2-7b-chat-int8",
|
||||
messages=[{"content": "what llm are you", "role": "user"}],
|
||||
max_tokens=15,
|
||||
num_retries=3,
|
||||
)
|
||||
print(response)
|
||||
|
||||
|
@ -1938,7 +1945,7 @@ def test_completion_cloudflare():
|
|||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_completion_cloudflare()
|
||||
test_completion_cloudflare()
|
||||
|
||||
|
||||
def test_moderation():
|
||||
|
|
|
@ -103,7 +103,7 @@ def test_cost_azure_gpt_35():
|
|||
),
|
||||
)
|
||||
],
|
||||
model="azure/gpt-35-turbo", # azure always has model written like this
|
||||
model="gpt-35-turbo", # azure always has model written like this
|
||||
usage=Usage(prompt_tokens=21, completion_tokens=17, total_tokens=38),
|
||||
)
|
||||
|
||||
|
@ -125,3 +125,36 @@ def test_cost_azure_gpt_35():
|
|||
|
||||
|
||||
test_cost_azure_gpt_35()
|
||||
|
||||
|
||||
def test_cost_azure_embedding():
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
async def _test():
|
||||
response = await litellm.aembedding(
|
||||
model="azure/azure-embedding-model",
|
||||
input=["good morning from litellm", "gm"],
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
||||
return response
|
||||
|
||||
response = asyncio.run(_test())
|
||||
|
||||
cost = litellm.completion_cost(completion_response=response)
|
||||
|
||||
print("Cost", cost)
|
||||
expected_cost = float("7e-07")
|
||||
assert cost == expected_cost
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(
|
||||
f"Cost Calc failed for azure/gpt-3.5-turbo. Expected {expected_cost}, Calculated cost {cost}"
|
||||
)
|
||||
|
||||
|
||||
# test_cost_azure_embedding()
|
|
@ -0,0 +1,15 @@
|
|||
model_list:
|
||||
- model_name: azure-cloudflare
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: https://gateway.ai.cloudflare.com/v1/0399b10e77ac6668c80404a5ff49eb37/litellm-test/azure-openai/openai-gpt-4-test-v-1
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: 2023-07-01-preview
|
||||
|
||||
litellm_settings:
|
||||
set_verbose: True
|
||||
cache: True # set cache responses to True
|
||||
cache_params: # set cache params for s3
|
||||
type: s3
|
||||
s3_bucket_name: cache-bucket-litellm # AWS Bucket Name for S3
|
||||
s3_region_name: us-west-2 # AWS Region Name for S3
|
|
@ -9,6 +9,11 @@ model_list:
|
|||
api_key: os.environ/AZURE_CANADA_API_KEY
|
||||
model: azure/gpt-35-turbo
|
||||
model_name: azure-model
|
||||
- litellm_params:
|
||||
api_base: https://gateway.ai.cloudflare.com/v1/0399b10e77ac6668c80404a5ff49eb37/litellm-test/azure-openai/openai-gpt-4-test-v-1
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
model: azure/chatgpt-v-2
|
||||
model_name: azure-cloudflare-model
|
||||
- litellm_params:
|
||||
api_base: https://openai-france-1234.openai.azure.com
|
||||
api_key: os.environ/AZURE_FRANCE_API_KEY
|
||||
|
|
|
@ -59,6 +59,7 @@ def test_openai_embedding():
|
|||
|
||||
def test_openai_azure_embedding_simple():
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
response = embedding(
|
||||
model="azure/azure-embedding-model",
|
||||
input=["good morning from litellm"],
|
||||
|
@ -70,6 +71,10 @@ def test_openai_azure_embedding_simple():
|
|||
response_keys
|
||||
) # assert litellm response has expected keys from OpenAI embedding response
|
||||
|
||||
request_cost = litellm.completion_cost(completion_response=response)
|
||||
|
||||
print("Calculated request cost=", request_cost)
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
@ -260,15 +265,22 @@ def test_aembedding():
|
|||
input=["good morning from litellm", "this is another item"],
|
||||
)
|
||||
print(response)
|
||||
return response
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
asyncio.run(embedding_call())
|
||||
response = asyncio.run(embedding_call())
|
||||
print("Before caclulating cost, response", response)
|
||||
|
||||
cost = litellm.completion_cost(completion_response=response)
|
||||
|
||||
print("COST=", cost)
|
||||
assert cost == float("1e-06")
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_aembedding()
|
||||
test_aembedding()
|
||||
|
||||
|
||||
def test_aembedding_azure():
|
||||
|
|
|
@ -10,7 +10,7 @@ import os, io
|
|||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import pytest, asyncio
|
||||
import litellm
|
||||
from litellm import embedding, completion, completion_cost, Timeout
|
||||
from litellm import RateLimitError
|
||||
|
@ -22,6 +22,7 @@ from litellm.proxy.proxy_server import (
|
|||
router,
|
||||
save_worker_config,
|
||||
initialize,
|
||||
ProxyConfig,
|
||||
) # Replace with the actual module where your FastAPI router is defined
|
||||
|
||||
|
||||
|
@ -36,7 +37,7 @@ def client():
|
|||
config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
|
||||
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
||||
app = FastAPI()
|
||||
initialize(config=config_fp)
|
||||
asyncio.run(initialize(config=config_fp))
|
||||
|
||||
app.include_router(router) # Include your router in the test app
|
||||
return TestClient(app)
|
||||
|
|
|
@ -23,6 +23,7 @@ from litellm.proxy.proxy_server import (
|
|||
router,
|
||||
save_worker_config,
|
||||
initialize,
|
||||
startup_event,
|
||||
) # Replace with the actual module where your FastAPI router is defined
|
||||
|
||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
|
@ -39,8 +40,8 @@ python_file_path = f"{filepath}/test_configs/custom_callbacks.py"
|
|||
def client():
|
||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
config_fp = f"{filepath}/test_configs/test_custom_logger.yaml"
|
||||
initialize(config=config_fp)
|
||||
app = FastAPI()
|
||||
asyncio.run(initialize(config=config_fp))
|
||||
app.include_router(router) # Include your router in the test app
|
||||
return TestClient(app)
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ from litellm.proxy.proxy_server import (
|
|||
def client():
|
||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
config_fp = f"{filepath}/test_configs/test_bad_config.yaml"
|
||||
initialize(config=config_fp)
|
||||
asyncio.run(initialize(config=config_fp))
|
||||
app = FastAPI()
|
||||
app.include_router(router) # Include your router in the test app
|
||||
return TestClient(app)
|
||||
|
@ -149,7 +149,7 @@ def test_chat_completion_exception_any_model(client):
|
|||
response=response
|
||||
)
|
||||
print("Exception raised=", openai_exception)
|
||||
assert isinstance(openai_exception, openai.NotFoundError)
|
||||
assert isinstance(openai_exception, openai.BadRequestError)
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
|
||||
|
@ -170,7 +170,7 @@ def test_embedding_exception_any_model(client):
|
|||
response=response
|
||||
)
|
||||
print("Exception raised=", openai_exception)
|
||||
assert isinstance(openai_exception, openai.NotFoundError)
|
||||
assert isinstance(openai_exception, openai.BadRequestError)
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
|
||||
|
|
|
@ -10,7 +10,7 @@ import os, io
|
|||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest, logging
|
||||
import pytest, logging, asyncio
|
||||
import litellm
|
||||
from litellm import embedding, completion, completion_cost, Timeout
|
||||
from litellm import RateLimitError
|
||||
|
@ -46,7 +46,7 @@ def client_no_auth():
|
|||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
|
||||
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
||||
initialize(config=config_fp, debug=True)
|
||||
asyncio.run(initialize(config=config_fp, debug=True))
|
||||
app = FastAPI()
|
||||
app.include_router(router) # Include your router in the test app
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ import os, io
|
|||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest, logging
|
||||
import pytest, logging, asyncio
|
||||
import litellm
|
||||
from litellm import embedding, completion, completion_cost, Timeout
|
||||
from litellm import RateLimitError
|
||||
|
@ -45,7 +45,7 @@ def client_no_auth():
|
|||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
|
||||
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
||||
initialize(config=config_fp)
|
||||
asyncio.run(initialize(config=config_fp, debug=True))
|
||||
app = FastAPI()
|
||||
app.include_router(router) # Include your router in the test app
|
||||
|
||||
|
@ -280,44 +280,55 @@ def test_chat_completion_optional_params(client_no_auth):
|
|||
# test_chat_completion_optional_params()
|
||||
|
||||
# Test Reading config.yaml file
|
||||
from litellm.proxy.proxy_server import load_router_config
|
||||
from litellm.proxy.proxy_server import ProxyConfig
|
||||
|
||||
|
||||
def test_load_router_config():
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
print("testing reading config")
|
||||
# this is a basic config.yaml with only a model
|
||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
result = load_router_config(
|
||||
proxy_config = ProxyConfig()
|
||||
result = asyncio.run(
|
||||
proxy_config.load_config(
|
||||
router=None,
|
||||
config_file_path=f"{filepath}/example_config_yaml/simple_config.yaml",
|
||||
)
|
||||
)
|
||||
print(result)
|
||||
assert len(result[1]) == 1
|
||||
|
||||
# this is a load balancing config yaml
|
||||
result = load_router_config(
|
||||
result = asyncio.run(
|
||||
proxy_config.load_config(
|
||||
router=None,
|
||||
config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml",
|
||||
)
|
||||
)
|
||||
print(result)
|
||||
assert len(result[1]) == 2
|
||||
|
||||
# config with general settings - custom callbacks
|
||||
result = load_router_config(
|
||||
result = asyncio.run(
|
||||
proxy_config.load_config(
|
||||
router=None,
|
||||
config_file_path=f"{filepath}/example_config_yaml/azure_config.yaml",
|
||||
)
|
||||
)
|
||||
print(result)
|
||||
assert len(result[1]) == 2
|
||||
|
||||
# tests for litellm.cache set from config
|
||||
print("testing reading proxy config for cache")
|
||||
litellm.cache = None
|
||||
load_router_config(
|
||||
asyncio.run(
|
||||
proxy_config.load_config(
|
||||
router=None,
|
||||
config_file_path=f"{filepath}/example_config_yaml/cache_no_params.yaml",
|
||||
)
|
||||
)
|
||||
assert litellm.cache is not None
|
||||
assert "redis_client" in vars(
|
||||
litellm.cache.cache
|
||||
|
@ -329,11 +340,15 @@ def test_load_router_config():
|
|||
"aembedding",
|
||||
] # init with all call types
|
||||
|
||||
litellm.disable_cache()
|
||||
|
||||
print("testing reading proxy config for cache with params")
|
||||
load_router_config(
|
||||
asyncio.run(
|
||||
proxy_config.load_config(
|
||||
router=None,
|
||||
config_file_path=f"{filepath}/example_config_yaml/cache_with_params.yaml",
|
||||
)
|
||||
)
|
||||
assert litellm.cache is not None
|
||||
print(litellm.cache)
|
||||
print(litellm.cache.supported_call_types)
|
||||
|
|
|
@ -1,38 +1,103 @@
|
|||
# #### What this tests ####
|
||||
# # This tests using caching w/ litellm which requires SSL=True
|
||||
#### What this tests ####
|
||||
# This tests using caching w/ litellm which requires SSL=True
|
||||
import sys, os
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# import sys, os
|
||||
# import time
|
||||
# import traceback
|
||||
# from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
import os, io
|
||||
|
||||
# load_dotenv()
|
||||
# import os
|
||||
# this file is to test litellm/proxy
|
||||
|
||||
# sys.path.insert(
|
||||
# 0, os.path.abspath("../..")
|
||||
# ) # Adds the parent directory to the system path
|
||||
# import pytest
|
||||
# import litellm
|
||||
# from litellm import embedding, completion
|
||||
# from litellm.caching import Cache
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest, logging, asyncio
|
||||
import litellm
|
||||
from litellm import embedding, completion, completion_cost, Timeout
|
||||
from litellm import RateLimitError
|
||||
|
||||
# messages = [{"role": "user", "content": f"who is ishaan {time.time()}"}]
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG, # Set the desired logging level
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
# @pytest.mark.skip(reason="local proxy test")
|
||||
# def test_caching_v2(): # test in memory cache
|
||||
# try:
|
||||
# response1 = completion(model="openai/gpt-3.5-turbo", messages=messages, api_base="http://0.0.0.0:8000")
|
||||
# response2 = completion(model="openai/gpt-3.5-turbo", messages=messages, api_base="http://0.0.0.0:8000")
|
||||
# print(f"response1: {response1}")
|
||||
# print(f"response2: {response2}")
|
||||
# litellm.cache = None # disable cache
|
||||
# if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
|
||||
# print(f"response1: {response1}")
|
||||
# print(f"response2: {response2}")
|
||||
# raise Exception()
|
||||
# except Exception as e:
|
||||
# print(f"error occurred: {traceback.format_exc()}")
|
||||
# pytest.fail(f"Error occurred: {e}")
|
||||
# test /chat/completion request to the proxy
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import FastAPI
|
||||
from litellm.proxy.proxy_server import (
|
||||
router,
|
||||
save_worker_config,
|
||||
initialize,
|
||||
) # Replace with the actual module where your FastAPI router is defined
|
||||
|
||||
# test_caching_v2()
|
||||
# Your bearer token
|
||||
token = ""
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client_no_auth():
|
||||
# Assuming litellm.proxy.proxy_server is an object
|
||||
from litellm.proxy.proxy_server import cleanup_router_config_variables
|
||||
|
||||
cleanup_router_config_variables()
|
||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
config_fp = f"{filepath}/test_configs/test_cloudflare_azure_with_cache_config.yaml"
|
||||
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
||||
asyncio.run(initialize(config=config_fp, debug=True))
|
||||
app = FastAPI()
|
||||
app.include_router(router) # Include your router in the test app
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def generate_random_word(length=4):
|
||||
import string, random
|
||||
|
||||
letters = string.ascii_lowercase
|
||||
return "".join(random.choice(letters) for _ in range(length))
|
||||
|
||||
|
||||
def test_chat_completion(client_no_auth):
|
||||
global headers
|
||||
try:
|
||||
user_message = f"Write a poem about {generate_random_word()}"
|
||||
messages = [{"content": user_message, "role": "user"}]
|
||||
# Your test data
|
||||
test_data = {
|
||||
"model": "azure-cloudflare",
|
||||
"messages": messages,
|
||||
"max_tokens": 10,
|
||||
}
|
||||
|
||||
print("testing proxy server with chat completions")
|
||||
response = client_no_auth.post("/v1/chat/completions", json=test_data)
|
||||
print(f"response - {response.text}")
|
||||
assert response.status_code == 200
|
||||
|
||||
response = response.json()
|
||||
print(response)
|
||||
|
||||
content = response["choices"][0]["message"]["content"]
|
||||
response1_id = response["id"]
|
||||
|
||||
print("\n content", content)
|
||||
|
||||
assert len(content) > 1
|
||||
|
||||
print("\nmaking 2nd request to proxy. Testing caching + non streaming")
|
||||
response = client_no_auth.post("/v1/chat/completions", json=test_data)
|
||||
print(f"response - {response.text}")
|
||||
assert response.status_code == 200
|
||||
|
||||
response = response.json()
|
||||
print(response)
|
||||
response2_id = response["id"]
|
||||
assert response1_id == response2_id
|
||||
litellm.disable_cache()
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
||||
|
|
|
@ -29,6 +29,7 @@ from litellm.proxy.proxy_server import (
|
|||
router,
|
||||
save_worker_config,
|
||||
startup_event,
|
||||
asyncio,
|
||||
) # Replace with the actual module where your FastAPI router is defined
|
||||
|
||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
|
@ -39,7 +40,7 @@ save_worker_config(
|
|||
alias=None,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
debug=False,
|
||||
debug=True,
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
request_timeout=600,
|
||||
|
@ -51,24 +52,38 @@ save_worker_config(
|
|||
save=False,
|
||||
use_queue=False,
|
||||
)
|
||||
app = FastAPI()
|
||||
app.include_router(router) # Include your router in the test app
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def wrapper_startup_event():
|
||||
await startup_event()
|
||||
import asyncio
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_loop():
|
||||
"""Create an instance of the default event loop for each test case."""
|
||||
policy = asyncio.WindowsSelectorEventLoopPolicy()
|
||||
res = policy.new_event_loop()
|
||||
asyncio.set_event_loop(res)
|
||||
res._close = res.close
|
||||
res.close = lambda: None
|
||||
|
||||
yield res
|
||||
|
||||
res._close()
|
||||
|
||||
|
||||
# Here you create a fixture that will be used by your tests
|
||||
# Make sure the fixture returns TestClient(app)
|
||||
@pytest.fixture(autouse=True)
|
||||
@pytest.fixture(scope="function")
|
||||
def client():
|
||||
from litellm.proxy.proxy_server import cleanup_router_config_variables
|
||||
from litellm.proxy.proxy_server import cleanup_router_config_variables, initialize
|
||||
|
||||
cleanup_router_config_variables()
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
cleanup_router_config_variables() # rest proxy before test
|
||||
|
||||
asyncio.run(initialize(config=config_fp, debug=True))
|
||||
app = FastAPI()
|
||||
app.include_router(router) # Include your router in the test app
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_add_new_key(client):
|
||||
|
@ -79,7 +94,7 @@ def test_add_new_key(client):
|
|||
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
|
||||
"duration": "20m",
|
||||
}
|
||||
print("testing proxy server")
|
||||
print("testing proxy server - test_add_new_key")
|
||||
# Your bearer token
|
||||
token = os.getenv("PROXY_MASTER_KEY")
|
||||
|
||||
|
@ -121,7 +136,7 @@ def test_update_new_key(client):
|
|||
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
|
||||
"duration": "20m",
|
||||
}
|
||||
print("testing proxy server")
|
||||
print("testing proxy server-test_update_new_key")
|
||||
# Your bearer token
|
||||
token = os.getenv("PROXY_MASTER_KEY")
|
||||
|
||||
|
|
|
@ -98,6 +98,73 @@ def test_init_clients_basic():
|
|||
# test_init_clients_basic()
|
||||
|
||||
|
||||
def test_init_clients_basic_azure_cloudflare():
|
||||
# init azure + cloudflare
|
||||
# init OpenAI gpt-3.5
|
||||
# init OpenAI text-embedding
|
||||
# init OpenAI comptaible - Mistral/mistral-medium
|
||||
# init OpenAI compatible - xinference/bge
|
||||
litellm.set_verbose = True
|
||||
try:
|
||||
print("Test basic client init")
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "azure-cloudflare",
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": "https://gateway.ai.cloudflare.com/v1/0399b10e77ac6668c80404a5ff49eb37/litellm-test/azure-openai/openai-gpt-4-test-v-1",
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-openai",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "text-embedding-ada-002",
|
||||
"litellm_params": {
|
||||
"model": "text-embedding-ada-002",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "mistral",
|
||||
"litellm_params": {
|
||||
"model": "mistral/mistral-tiny",
|
||||
"api_key": os.getenv("MISTRAL_API_KEY"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "bge-base-en",
|
||||
"litellm_params": {
|
||||
"model": "xinference/bge-base-en",
|
||||
"api_base": "http://127.0.0.1:9997/v1",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
},
|
||||
]
|
||||
router = Router(model_list=model_list)
|
||||
for elem in router.model_list:
|
||||
model_id = elem["model_info"]["id"]
|
||||
assert router.cache.get_cache(f"{model_id}_client") is not None
|
||||
assert router.cache.get_cache(f"{model_id}_async_client") is not None
|
||||
assert router.cache.get_cache(f"{model_id}_stream_client") is not None
|
||||
assert router.cache.get_cache(f"{model_id}_stream_async_client") is not None
|
||||
print("PASSED !")
|
||||
|
||||
# see if we can init clients without timeout or max retries set
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_init_clients_basic_azure_cloudflare()
|
||||
|
||||
|
||||
def test_timeouts_router():
|
||||
"""
|
||||
Test the timeouts of the router with multiple clients. This HASas to raise a timeout error
|
||||
|
|
137
litellm/tests/test_router_policy_violation.py
Normal file
137
litellm/tests/test_router_policy_violation.py
Normal file
|
@ -0,0 +1,137 @@
|
|||
#### What this tests ####
|
||||
# This tests if the router sends back a policy violation, without retries
|
||||
|
||||
import sys, os, time
|
||||
import traceback, asyncio
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
import litellm
|
||||
from litellm import Router
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
|
||||
class MyCustomHandler(CustomLogger):
|
||||
success: bool = False
|
||||
failure: bool = False
|
||||
previous_models: int = 0
|
||||
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
print(f"Pre-API Call")
|
||||
print(
|
||||
f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}"
|
||||
)
|
||||
self.previous_models += len(
|
||||
kwargs["litellm_params"]["metadata"]["previous_models"]
|
||||
) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": <complete_traceback>}]}
|
||||
print(f"self.previous_models: {self.previous_models}")
|
||||
|
||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||
print(
|
||||
f"Post-API Call - response object: {response_obj}; model: {kwargs['model']}"
|
||||
)
|
||||
|
||||
def log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Stream")
|
||||
|
||||
def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Stream")
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Success")
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Success")
|
||||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Failure")
|
||||
|
||||
|
||||
kwargs = {
|
||||
"model": "azure/gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "vorrei vedere la cosa più bella ad Ercolano. Qual’è?",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_fallbacks():
|
||||
litellm.set_verbose = False
|
||||
model_list = [
|
||||
{ # list of model deployments
|
||||
"model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800,
|
||||
},
|
||||
{
|
||||
"model_name": "azure/gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-functioncalling",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800,
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
"tpm": 1000000,
|
||||
"rpm": 9000,
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo-16k", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "gpt-3.5-turbo-16k",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
"tpm": 1000000,
|
||||
"rpm": 9000,
|
||||
},
|
||||
]
|
||||
|
||||
router = Router(
|
||||
model_list=model_list,
|
||||
num_retries=3,
|
||||
fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
|
||||
# context_window_fallbacks=[
|
||||
# {"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]},
|
||||
# {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]},
|
||||
# ],
|
||||
set_verbose=False,
|
||||
)
|
||||
customHandler = MyCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
try:
|
||||
response = await router.acompletion(**kwargs)
|
||||
pytest.fail(
|
||||
f"An exception occurred: {e}"
|
||||
) # should've raised azure policy error
|
||||
except litellm.Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
await asyncio.sleep(
|
||||
0.05
|
||||
) # allow a delay as success_callbacks are on a separate thread
|
||||
assert customHandler.previous_models == 0 # 0 retries, 0 fallback
|
||||
router.reset()
|
||||
finally:
|
||||
router.reset()
|
|
@ -306,6 +306,8 @@ def test_completion_ollama_hosted_stream():
|
|||
model="ollama/phi",
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
num_retries=3,
|
||||
timeout=90,
|
||||
api_base="https://test-ollama-endpoint.onrender.com",
|
||||
stream=True,
|
||||
)
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
import sys, re, binascii, struct
|
||||
import litellm
|
||||
import dotenv, json, traceback, threading, base64
|
||||
import dotenv, json, traceback, threading, base64, ast
|
||||
import subprocess, os
|
||||
import litellm, openai
|
||||
import itertools
|
||||
|
@ -1975,7 +1975,10 @@ def client(original_function):
|
|||
if (
|
||||
(kwargs.get("caching", None) is None and litellm.cache is not None)
|
||||
or kwargs.get("caching", False) == True
|
||||
or kwargs.get("cache", {}).get("no-cache", False) != True
|
||||
or (
|
||||
kwargs.get("cache", None) is not None
|
||||
and kwargs.get("cache", {}).get("no-cache", False) != True
|
||||
)
|
||||
): # allow users to control returning cached responses from the completion function
|
||||
# checking cache
|
||||
print_verbose(f"INSIDE CHECKING CACHE")
|
||||
|
@ -2737,6 +2740,8 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
|
|||
completion_tokens_cost_usd_dollar = 0
|
||||
model_cost_ref = litellm.model_cost
|
||||
# see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
|
||||
print_verbose(f"Looking up model={model} in model_cost_map")
|
||||
|
||||
if model in model_cost_ref:
|
||||
prompt_tokens_cost_usd_dollar = (
|
||||
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
|
||||
|
@ -2746,6 +2751,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
|
|||
)
|
||||
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
||||
elif "ft:gpt-3.5-turbo" in model:
|
||||
print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
|
||||
# fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm
|
||||
prompt_tokens_cost_usd_dollar = (
|
||||
model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens
|
||||
|
@ -2756,6 +2762,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
|
|||
)
|
||||
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
||||
elif model in litellm.azure_llms:
|
||||
print_verbose(f"Cost Tracking: {model} is an Azure LLM")
|
||||
model = litellm.azure_llms[model]
|
||||
prompt_tokens_cost_usd_dollar = (
|
||||
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
|
||||
|
@ -2764,19 +2771,29 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
|
|||
model_cost_ref[model]["output_cost_per_token"] * completion_tokens
|
||||
)
|
||||
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
||||
else:
|
||||
# calculate average input cost, azure/gpt-deployments can potentially go here if users don't specify, gpt-4, gpt-3.5-turbo. LLMs litellm knows
|
||||
input_cost_sum = 0
|
||||
output_cost_sum = 0
|
||||
model_cost_ref = litellm.model_cost
|
||||
for model in model_cost_ref:
|
||||
input_cost_sum += model_cost_ref[model]["input_cost_per_token"]
|
||||
output_cost_sum += model_cost_ref[model]["output_cost_per_token"]
|
||||
avg_input_cost = input_cost_sum / len(model_cost_ref.keys())
|
||||
avg_output_cost = output_cost_sum / len(model_cost_ref.keys())
|
||||
prompt_tokens_cost_usd_dollar = avg_input_cost * prompt_tokens
|
||||
completion_tokens_cost_usd_dollar = avg_output_cost * completion_tokens
|
||||
elif model in litellm.azure_embedding_models:
|
||||
print_verbose(f"Cost Tracking: {model} is an Azure Embedding Model")
|
||||
model = litellm.azure_embedding_models[model]
|
||||
prompt_tokens_cost_usd_dollar = (
|
||||
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
|
||||
)
|
||||
completion_tokens_cost_usd_dollar = (
|
||||
model_cost_ref[model]["output_cost_per_token"] * completion_tokens
|
||||
)
|
||||
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
||||
else:
|
||||
# if model is not in model_prices_and_context_window.json. Raise an exception-let users know
|
||||
error_str = f"Model not in model_prices_and_context_window.json. You passed model={model}\n"
|
||||
raise litellm.exceptions.NotFoundError( # type: ignore
|
||||
message=error_str,
|
||||
model=model,
|
||||
response=httpx.Response(
|
||||
status_code=404,
|
||||
content=error_str,
|
||||
request=httpx.request(method="cost_per_token", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
llm_provider="",
|
||||
)
|
||||
|
||||
|
||||
def completion_cost(
|
||||
|
@ -2818,8 +2835,10 @@ def completion_cost(
|
|||
completion_tokens = 0
|
||||
if completion_response is not None:
|
||||
# get input/output tokens from completion_response
|
||||
prompt_tokens = completion_response["usage"]["prompt_tokens"]
|
||||
completion_tokens = completion_response["usage"]["completion_tokens"]
|
||||
prompt_tokens = completion_response.get("usage", {}).get("prompt_tokens", 0)
|
||||
completion_tokens = completion_response.get("usage", {}).get(
|
||||
"completion_tokens", 0
|
||||
)
|
||||
model = (
|
||||
model or completion_response["model"]
|
||||
) # check if user passed an override for model, if it's none check completion_response['model']
|
||||
|
@ -2829,6 +2848,10 @@ def completion_cost(
|
|||
elif len(prompt) > 0:
|
||||
prompt_tokens = token_counter(model=model, text=prompt)
|
||||
completion_tokens = token_counter(model=model, text=completion)
|
||||
if model == None:
|
||||
raise ValueError(
|
||||
f"Model is None and does not exist in passed completion_response. Passed completion_response={completion_response}, model={model}"
|
||||
)
|
||||
|
||||
# Calculate cost based on prompt_tokens, completion_tokens
|
||||
if "togethercomputer" in model or "together_ai" in model:
|
||||
|
@ -2849,8 +2872,7 @@ def completion_cost(
|
|||
)
|
||||
return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
|
||||
except Exception as e:
|
||||
print_verbose(f"LiteLLM: Excepton when cost calculating {str(e)}")
|
||||
return 0.0 # this should not block a users execution path
|
||||
raise e
|
||||
|
||||
|
||||
####### HELPER FUNCTIONS ################
|
||||
|
@ -4081,11 +4103,11 @@ def get_llm_provider(
|
|||
print() # noqa
|
||||
error_str = f"LLM Provider NOT provided. Pass in the LLM provider you are trying to call. You passed model={model}\n Pass model as E.g. For 'Huggingface' inference endpoints pass in `completion(model='huggingface/starcoder',..)` Learn more: https://docs.litellm.ai/docs/providers"
|
||||
# maps to openai.NotFoundError, this is raised when openai does not recognize the llm
|
||||
raise litellm.exceptions.NotFoundError( # type: ignore
|
||||
raise litellm.exceptions.BadRequestError( # type: ignore
|
||||
message=error_str,
|
||||
model=model,
|
||||
response=httpx.Response(
|
||||
status_code=404,
|
||||
status_code=400,
|
||||
content=error_str,
|
||||
request=httpx.request(method="completion", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
|
@ -4915,6 +4937,9 @@ async def convert_to_streaming_response_async(response_object: Optional[dict] =
|
|||
if "id" in response_object:
|
||||
model_response_object.id = response_object["id"]
|
||||
|
||||
if "created" in response_object:
|
||||
model_response_object.created = response_object["created"]
|
||||
|
||||
if "system_fingerprint" in response_object:
|
||||
model_response_object.system_fingerprint = response_object["system_fingerprint"]
|
||||
|
||||
|
@ -4959,6 +4984,9 @@ def convert_to_streaming_response(response_object: Optional[dict] = None):
|
|||
if "id" in response_object:
|
||||
model_response_object.id = response_object["id"]
|
||||
|
||||
if "created" in response_object:
|
||||
model_response_object.created = response_object["created"]
|
||||
|
||||
if "system_fingerprint" in response_object:
|
||||
model_response_object.system_fingerprint = response_object["system_fingerprint"]
|
||||
|
||||
|
@ -5014,6 +5042,9 @@ def convert_to_model_response_object(
|
|||
model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore
|
||||
model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
|
||||
|
||||
if "created" in response_object:
|
||||
model_response_object.created = response_object["created"]
|
||||
|
||||
if "id" in response_object:
|
||||
model_response_object.id = response_object["id"]
|
||||
|
||||
|
@ -6621,7 +6652,7 @@ def _is_base64(s):
|
|||
|
||||
def get_secret(
|
||||
secret_name: str,
|
||||
default_value: Optional[str] = None,
|
||||
default_value: Optional[Union[str, bool]] = None,
|
||||
):
|
||||
key_management_system = litellm._key_management_system
|
||||
if secret_name.startswith("os.environ/"):
|
||||
|
@ -6672,9 +6703,24 @@ def get_secret(
|
|||
secret = client.get_secret(secret_name).secret_value
|
||||
except Exception as e: # check if it's in os.environ
|
||||
secret = os.getenv(secret_name)
|
||||
try:
|
||||
secret_value_as_bool = ast.literal_eval(secret)
|
||||
if isinstance(secret_value_as_bool, bool):
|
||||
return secret_value_as_bool
|
||||
else:
|
||||
return secret
|
||||
except:
|
||||
return secret
|
||||
else:
|
||||
return os.environ.get(secret_name)
|
||||
secret = os.environ.get(secret_name)
|
||||
try:
|
||||
secret_value_as_bool = ast.literal_eval(secret)
|
||||
if isinstance(secret_value_as_bool, bool):
|
||||
return secret_value_as_bool
|
||||
else:
|
||||
return secret
|
||||
except:
|
||||
return secret
|
||||
except Exception as e:
|
||||
if default_value is not None:
|
||||
return default_value
|
||||
|
|
|
@ -111,6 +111,13 @@
|
|||
"litellm_provider": "openai",
|
||||
"mode": "embedding"
|
||||
},
|
||||
"text-embedding-ada-002-v2": {
|
||||
"max_tokens": 8191,
|
||||
"input_cost_per_token": 0.0000001,
|
||||
"output_cost_per_token": 0.000000,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "embedding"
|
||||
},
|
||||
"256-x-256/dall-e-2": {
|
||||
"mode": "image_generation",
|
||||
"input_cost_per_pixel": 0.00000024414,
|
||||
|
@ -242,6 +249,13 @@
|
|||
"litellm_provider": "azure",
|
||||
"mode": "chat"
|
||||
},
|
||||
"azure/ada": {
|
||||
"max_tokens": 8191,
|
||||
"input_cost_per_token": 0.0000001,
|
||||
"output_cost_per_token": 0.000000,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "embedding"
|
||||
},
|
||||
"azure/text-embedding-ada-002": {
|
||||
"max_tokens": 8191,
|
||||
"input_cost_per_token": 0.0000001,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "litellm"
|
||||
version = "1.16.13"
|
||||
version = "1.16.14"
|
||||
description = "Library to easily interface with LLM API providers"
|
||||
authors = ["BerriAI"]
|
||||
license = "MIT License"
|
||||
|
@ -59,7 +59,7 @@ requires = ["poetry-core", "wheel"]
|
|||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.commitizen]
|
||||
version = "1.16.13"
|
||||
version = "1.16.14"
|
||||
version_files = [
|
||||
"pyproject.toml:^version"
|
||||
]
|
||||
|
|
28
retry_push.sh
Normal file
28
retry_push.sh
Normal file
|
@ -0,0 +1,28 @@
|
|||
#!/bin/bash
|
||||
|
||||
retry_count=0
|
||||
max_retries=3
|
||||
exit_code=1
|
||||
|
||||
until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ]
|
||||
do
|
||||
retry_count=$((retry_count+1))
|
||||
echo "Attempt $retry_count..."
|
||||
|
||||
# Run the Prisma db push command
|
||||
prisma db push --accept-data-loss
|
||||
|
||||
exit_code=$?
|
||||
|
||||
if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then
|
||||
echo "Retrying in 10 seconds..."
|
||||
sleep 10
|
||||
fi
|
||||
done
|
||||
|
||||
if [ $exit_code -ne 0 ]; then
|
||||
echo "Unable to push database changes after $max_retries retries."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Database push successful!"
|
33
schema.prisma
Normal file
33
schema.prisma
Normal file
|
@ -0,0 +1,33 @@
|
|||
datasource client {
|
||||
provider = "postgresql"
|
||||
url = env("DATABASE_URL")
|
||||
}
|
||||
|
||||
generator client {
|
||||
provider = "prisma-client-py"
|
||||
}
|
||||
|
||||
model LiteLLM_UserTable {
|
||||
user_id String @unique
|
||||
max_budget Float?
|
||||
spend Float @default(0.0)
|
||||
user_email String?
|
||||
}
|
||||
|
||||
// required for token gen
|
||||
model LiteLLM_VerificationToken {
|
||||
token String @unique
|
||||
spend Float @default(0.0)
|
||||
expires DateTime?
|
||||
models String[]
|
||||
aliases Json @default("{}")
|
||||
config Json @default("{}")
|
||||
user_id String?
|
||||
max_parallel_requests Int?
|
||||
metadata Json @default("{}")
|
||||
}
|
||||
|
||||
model LiteLLM_Config {
|
||||
param_name String @id
|
||||
param_value Json?
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue