Merge remote-tracking branch 'src/main'

Sébastien Campion 2024-01-27 19:24:35 +01:00
commit 4dd18b553a
29 changed files with 550 additions and 170 deletions


@@ -34,13 +34,6 @@ jobs:
       with:
         push: true
         tags: litellm/litellm:${{ github.event.inputs.tag || 'latest' }}
-    -
-      name: Build and push litellm-ui image
-      uses: docker/build-push-action@v5
-      with:
-        push: true
-        file: ui/Dockerfile
-        tags: litellm/litellm-ui:${{ github.event.inputs.tag || 'latest' }}
     -
       name: Build and push litellm-database image
       uses: docker/build-push-action@v5


@@ -13,8 +13,8 @@ response = embedding(model='text-embedding-ada-002', input=["good morning from l
 - `model`: *string* - ID of the model to use. `model='text-embedding-ada-002'`
-- `input`: *array* - Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for text-embedding-ada-002), cannot be an empty string, and any array must be 2048 dimensions or less.
+- `input`: *string or array* - Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for text-embedding-ada-002), cannot be an empty string, and any array must be 2048 dimensions or less.
-```
+```python
 input=["good morning from litellm"]
 ```
@@ -22,7 +22,11 @@ input=["good morning from litellm"]
 - `user`: *string (optional)* A unique identifier representing your end-user,
-- `timeout`: *integer* - The maximum time, in seconds, to wait for the API to respond. Defaults to 600 seconds (10 minutes).
+- `dimensions`: *integer (Optional)* The number of dimensions the resulting output embeddings should have. Only supported in OpenAI/Azure text-embedding-3 and later models.
+- `encoding_format`: *string (Optional)* The format to return the embeddings in. Can be either `"float"` or `"base64"`. Defaults to `encoding_format="float"`
+- `timeout`: *integer (Optional)* - The maximum time, in seconds, to wait for the API to respond. Defaults to 600 seconds (10 minutes).
 - `api_base`: *string (optional)* - The api endpoint you want to call the model with
@@ -66,7 +70,12 @@ input=["good morning from litellm"]
 from litellm import embedding
 import os
 os.environ['OPENAI_API_KEY'] = ""
-response = embedding('text-embedding-ada-002', input=["good morning from litellm"])
+response = embedding(
+    model="text-embedding-3-small",
+    input=["good morning from litellm", "this is another item"],
+    metadata={"anything": "good day"},
+    dimensions=5 # Only supported in text-embedding-3 and later models.
+)
 ```
 | Model Name | Function Call | Required OS Variables |
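For reference, a minimal call exercising both new parameters documented above (`dimensions` and `encoding_format`). This is illustrative only; the model name and API key are placeholders:

```python
import os
from litellm import embedding

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder key

response = embedding(
    model="text-embedding-3-small",   # dimensions is only honored by text-embedding-3+ models
    input=["good morning from litellm"],
    dimensions=256,
    encoding_format="float",          # or "base64"
)
# each returned vector should have 256 entries (data entries are dicts in this litellm version)
print(len(response.data[0]["embedding"]))
```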


@@ -1,6 +1,13 @@
 # Slack Alerting
-Get alerts for failed db read/writes, hanging api calls, failed api calls.
+Get alerts for:
+- hanging LLM api calls
+- failed LLM api calls
+- slow LLM api calls
+- budget Tracking per key/user:
+    - When a User/Key crosses their Budget
+    - When a User/Key is 15% away from crossing their Budget
+- failed db read/writes
 ## Quick Start
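These alerts ultimately land in Slack as plain incoming-webhook messages. A generic sketch of that delivery step is below; this is not the proxy's internal code, and the `SLACK_WEBHOOK_URL` variable name is an assumption:

```python
import os
import requests

def send_slack_alert(message: str) -> None:
    # POST a JSON payload to a Slack incoming webhook (env var name assumed)
    webhook_url = os.environ["SLACK_WEBHOOK_URL"]
    resp = requests.post(webhook_url, json={"text": message})
    resp.raise_for_status()

send_slack_alert("Budget Alert: key `sk-...abcd` is 15% away from its budget")
```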


@@ -605,6 +605,49 @@ response = router.completion(model="gpt-3.5-turbo", messages=messages)
 print(f"response: {response}")
 ```
+
+## Custom Callbacks - Track API Key, API Endpoint, Model Used
+
+If you need to track the api_key, api endpoint, model, custom_llm_provider used for each completion call, you can setup a [custom callback](https://docs.litellm.ai/docs/observability/custom_callback)
+
+### Usage
+
+```python
+import litellm
+from litellm.integrations.custom_logger import CustomLogger
+
+class MyCustomHandler(CustomLogger):
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Success")
+        print("kwargs=", kwargs)
+        litellm_params = kwargs.get("litellm_params")
+        api_key = litellm_params.get("api_key")
+        api_base = litellm_params.get("api_base")
+        custom_llm_provider = litellm_params.get("custom_llm_provider")
+        response_cost = kwargs.get("response_cost")
+
+        # print the values
+        print("api_key=", api_key)
+        print("api_base=", api_base)
+        print("custom_llm_provider=", custom_llm_provider)
+        print("response_cost=", response_cost)
+
+    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Failure")
+        print("kwargs=")
+
+customHandler = MyCustomHandler()
+
+litellm.callbacks = [customHandler]
+
+# Init Router
+router = Router(model_list=model_list, routing_strategy="simple-shuffle")
+
+# router completion call
+response = router.completion(
+    model="gpt-3.5-turbo",
+    messages=[{"role": "user", "content": "Hi who are you"}]
+)
+```
 ## Deploy Router


@@ -1,3 +1,12 @@
+# +-----------------------------------------------+
+# |                                               |
+# |          NOT PROXY BUDGET MANAGER             |
+# |   proxy budget manager is in proxy_server.py  |
+# |                                               |
+# +-----------------------------------------------+
+#
+#  Thank you users! We ❤️ you! - Krrish & Ishaan
+
 import os, json, time
 import litellm
 from litellm.utils import ModelResponse
@@ -16,7 +25,7 @@ class BudgetManager:
         self.client_type = client_type
         self.project_name = project_name
         self.api_base = api_base or "https://api.litellm.ai"
-        self.headers = headers or {'Content-Type': 'application/json'}
+        self.headers = headers or {"Content-Type": "application/json"}
         ## load the data or init the initial dictionaries
         self.load_data()
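For context, a minimal client-side sketch of using this class, following the upstream budget-manager docs; treat the method names as assumptions if your installed version differs:

```python
from litellm import BudgetManager, completion

budget_manager = BudgetManager(project_name="demo_project")  # local client_type by default
user = "user@example.com"

if not budget_manager.is_valid_user(user):
    budget_manager.create_budget(total_budget=10, user=user)  # $10 budget for this user

# only call the LLM while the user is under budget
if budget_manager.get_current_cost(user=user) <= budget_manager.get_total_budget(user):
    response = completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    budget_manager.update_cost(completion_obj=response, user=user)
else:
    response = "Sorry - no budget!"
```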


@@ -659,9 +659,16 @@ def completion(
         )
         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
-        prompt_tokens = len(encoding.encode(prompt))
-        completion_tokens = len(
-            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        prompt_tokens = response_metadata.get(
+            "x-amzn-bedrock-input-token-count", len(encoding.encode(prompt))
+        )
+        completion_tokens = response_metadata.get(
+            "x-amzn-bedrock-output-token-count",
+            len(
+                encoding.encode(
+                    model_response["choices"][0]["message"].get("content", "")
+                )
+            ),
         )

         model_response["created"] = int(time.time())
@@ -672,6 +679,8 @@ def completion(
             total_tokens=prompt_tokens + completion_tokens,
         )
         model_response.usage = usage
+        model_response._hidden_params["region_name"] = client.meta.region_name
+        print_verbose(f"model_response._hidden_params: {model_response._hidden_params}")
         return model_response
     except BedrockError as e:
         exception_mapping_worked = True
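The change above prefers Bedrock's own token counts (the `x-amzn-bedrock-*-token-count` response headers) and only falls back to local tokenization when they are absent. A standalone sketch of that fallback pattern, with illustrative names rather than the module's actual API:

```python
def count_tokens(response_metadata: dict, encoding, prompt: str, completion_text: str):
    """Prefer provider-reported token counts; fall back to counting with the tokenizer."""
    prompt_tokens = response_metadata.get(
        "x-amzn-bedrock-input-token-count",
        len(encoding.encode(prompt)),
    )
    completion_tokens = response_metadata.get(
        "x-amzn-bedrock-output-token-count",
        len(encoding.encode(completion_text)),
    )
    return prompt_tokens, completion_tokens
```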


@@ -718,8 +718,22 @@ class OpenAIChatCompletion(BaseLLM):
             return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation")  # type: ignore
         except OpenAIError as e:
             exception_mapping_worked = True
+            ## LOGGING
+            logging_obj.post_call(
+                input=prompt,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=str(e),
+            )
             raise e
         except Exception as e:
+            ## LOGGING
+            logging_obj.post_call(
+                input=prompt,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=str(e),
+            )
             if hasattr(e, "status_code"):
                 raise OpenAIError(status_code=e.status_code, message=str(e))
             else:


@ -10,7 +10,6 @@
import os, openai, sys, json, inspect, uuid, datetime, threading import os, openai, sys, json, inspect, uuid, datetime, threading
from typing import Any, Literal, Union from typing import Any, Literal, Union
from functools import partial from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy from copy import deepcopy
import httpx import httpx
@ -586,6 +585,10 @@ def completion(
) )
if model_response is not None and hasattr(model_response, "_hidden_params"): if model_response is not None and hasattr(model_response, "_hidden_params"):
model_response._hidden_params["custom_llm_provider"] = custom_llm_provider model_response._hidden_params["custom_llm_provider"] = custom_llm_provider
model_response._hidden_params["region_name"] = kwargs.get(
"aws_region_name", None
) # support region-based pricing for bedrock
### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ### ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
if input_cost_per_token is not None and output_cost_per_token is not None: if input_cost_per_token is not None and output_cost_per_token is not None:
litellm.register_model( litellm.register_model(
@ -2224,6 +2227,7 @@ def embedding(
model, model,
input=[], input=[],
# Optional params # Optional params
dimensions: Optional[int] = None,
timeout=600, # default to 10 minutes timeout=600, # default to 10 minutes
# set api_base, api_version, api_key # set api_base, api_version, api_key
api_base: Optional[str] = None, api_base: Optional[str] = None,
@ -2244,6 +2248,7 @@ def embedding(
Parameters: Parameters:
- model: The embedding model to use. - model: The embedding model to use.
- input: The input for which embeddings are to be generated. - input: The input for which embeddings are to be generated.
- dimensions: The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
- timeout: The timeout value for the API call, default 10 mins - timeout: The timeout value for the API call, default 10 mins
- litellm_call_id: The call ID for litellm logging. - litellm_call_id: The call ID for litellm logging.
- litellm_logging_obj: The litellm logging object. - litellm_logging_obj: The litellm logging object.
@ -2277,6 +2282,7 @@ def embedding(
output_cost_per_second = kwargs.get("output_cost_per_second", None) output_cost_per_second = kwargs.get("output_cost_per_second", None)
openai_params = [ openai_params = [
"user", "user",
"dimensions",
"request_timeout", "request_timeout",
"api_base", "api_base",
"api_version", "api_version",
@ -2345,7 +2351,9 @@ def embedding(
api_key=api_key, api_key=api_key,
) )
optional_params = get_optional_params_embeddings( optional_params = get_optional_params_embeddings(
model=model,
user=user, user=user,
dimensions=dimensions,
encoding_format=encoding_format, encoding_format=encoding_format,
custom_llm_provider=custom_llm_provider, custom_llm_provider=custom_llm_provider,
**non_default_params, **non_default_params,
@ -3067,7 +3075,7 @@ def image_generation(
custom_llm_provider=custom_llm_provider, custom_llm_provider=custom_llm_provider,
**non_default_params, **non_default_params,
) )
logging = litellm_logging_obj logging: Logging = litellm_logging_obj
logging.update_environment_variables( logging.update_environment_variables(
model=model, model=model,
user=user, user=user,


@@ -140,6 +140,7 @@ class GenerateRequestBase(LiteLLMBase):
 class GenerateKeyRequest(GenerateRequestBase):
+    key_alias: Optional[str] = None
     duration: Optional[str] = "1h"
     aliases: Optional[dict] = {}
     config: Optional[dict] = {}
@@ -304,6 +305,8 @@ class ConfigYAML(LiteLLMBase):
 class LiteLLM_VerificationToken(LiteLLMBase):
     token: str
+    key_name: Optional[str] = None
+    key_alias: Optional[str] = None
     spend: float = 0.0
     max_budget: Optional[float] = None
     expires: Union[str, None]
@@ -346,11 +349,12 @@ class LiteLLM_SpendLogs(LiteLLMBase):
     model: Optional[str] = ""
     call_type: str
     spend: Optional[float] = 0.0
+    total_tokens: Optional[int] = 0
+    prompt_tokens: Optional[int] = 0
+    completion_tokens: Optional[int] = 0
     startTime: Union[str, datetime, None]
     endTime: Union[str, datetime, None]
     user: Optional[str] = ""
-    modelParameters: Optional[Json] = {}
-    usage: Optional[Json] = {}
     metadata: Optional[Json] = {}
     cache_hit: Optional[str] = "False"
     cache_key: Optional[str] = None
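The three new flat token columns replace the old `modelParameters`/`usage` JSON blobs. A tiny illustrative sketch of how a spend-log row is flattened from a response's usage block (this mirrors the `get_logging_payload` change further down; the values are made up):

```python
usage = {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21}  # example values

spend_log_row = {
    "total_tokens": usage.get("total_tokens", 0),
    "prompt_tokens": usage.get("prompt_tokens", 0),
    "completion_tokens": usage.get("completion_tokens", 0),
}
```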


@ -5,6 +5,7 @@ from litellm.proxy._types import (
LiteLLM_Config, LiteLLM_Config,
LiteLLM_UserTable, LiteLLM_UserTable,
) )
from litellm.proxy.utils import hash_token
from litellm import get_secret from litellm import get_secret
from typing import Any, List, Literal, Optional, Union from typing import Any, List, Literal, Optional, Union
import json import json
@ -187,6 +188,8 @@ class DynamoDBWrapper(CustomDB):
table = client.table(self.database_arguments.spend_table_name) table = client.table(self.database_arguments.spend_table_name)
for k, v in value.items(): for k, v in value.items():
if k == "token" and value[k].startswith("sk-"):
value[k] = hash_token(token=v)
if isinstance(v, datetime): if isinstance(v, datetime):
value[k] = v.isoformat() value[k] = v.isoformat()
@ -229,6 +232,10 @@ class DynamoDBWrapper(CustomDB):
table = client.table(self.database_arguments.config_table_name) table = client.table(self.database_arguments.config_table_name)
key_name = "param_name" key_name = "param_name"
if key_name == "token" and key.startswith("sk-"):
# ensure it's hashed
key = hash_token(token=key)
response = await table.get_item({key_name: key}) response = await table.get_item({key_name: key})
new_response: Any = None new_response: Any = None
@ -308,6 +315,8 @@ class DynamoDBWrapper(CustomDB):
# Convert datetime object to ISO8601 string # Convert datetime object to ISO8601 string
if isinstance(v, datetime): if isinstance(v, datetime):
v = v.isoformat() v = v.isoformat()
if k == "token" and value[k].startswith("sk-"):
value[k] = hash_token(token=v)
# Accumulate updates # Accumulate updates
actions.append((F(k), Value(value=v))) actions.append((F(k), Value(value=v)))
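`hash_token` (imported from `litellm.proxy.utils` above) is what turns a raw `sk-...` key into the value stored and looked up here. Its body is not part of this diff; the sketch below simply assumes a SHA-256 hex digest for illustration:

```python
import hashlib

def hash_token_sketch(token: str) -> str:
    # Assumption: a plain SHA-256 hex digest of the raw key; see
    # litellm.proxy.utils.hash_token for the authoritative implementation.
    return hashlib.sha256(token.encode("utf-8")).hexdigest()

print(hash_token_sketch("sk-QWrxEynunsNpV1zT48HIrw"))
```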


@@ -11,6 +11,12 @@ model_list:
       output_cost_per_token: 0.00003
       max_tokens: 4096
       base_model: gpt-3.5-turbo
+  - model_name: gpt-4
+    litellm_params:
+      model: azure/chatgpt-v-2
+      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
+      api_version: "2023-05-15"
+      api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
   - model_name: gpt-vision
     litellm_params:
       model: azure/gpt-4-vision
@@ -61,7 +67,7 @@ model_list:
 litellm_settings:
   fallbacks: [{"openai-gpt-3.5": ["azure-gpt-3.5"]}]
   success_callback: ['langfuse']
-  max_budget: 0.025 # global budget for proxy
+  max_budget: 10 # global budget for proxy
   budget_duration: 30d # global budget duration, will reset after 30d
   # cache: True
 # setting callback class


@ -75,6 +75,7 @@ from litellm.proxy.utils import (
send_email, send_email,
get_logging_payload, get_logging_payload,
reset_budget, reset_budget,
hash_token,
) )
from litellm.proxy.secret_managers.google_kms import load_google_kms from litellm.proxy.secret_managers.google_kms import load_google_kms
import pydantic import pydantic
@ -243,6 +244,8 @@ async def user_api_key_auth(
response = await user_custom_auth(request=request, api_key=api_key) response = await user_custom_auth(request=request, api_key=api_key)
return UserAPIKeyAuth.model_validate(response) return UserAPIKeyAuth.model_validate(response)
### LITELLM-DEFINED AUTH FUNCTION ### ### LITELLM-DEFINED AUTH FUNCTION ###
if isinstance(api_key, str):
assert api_key.startswith("sk-") # prevent token hashes from being used
if master_key is None: if master_key is None:
if isinstance(api_key, str): if isinstance(api_key, str):
return UserAPIKeyAuth(api_key=api_key) return UserAPIKeyAuth(api_key=api_key)
@ -288,8 +291,9 @@ async def user_api_key_auth(
raise Exception("No connected db.") raise Exception("No connected db.")
## check for cache hit (In-Memory Cache) ## check for cache hit (In-Memory Cache)
if api_key.startswith("sk-"):
api_key = hash_token(token=api_key)
valid_token = user_api_key_cache.get_cache(key=api_key) valid_token = user_api_key_cache.get_cache(key=api_key)
verbose_proxy_logger.debug(f"valid_token from cache: {valid_token}")
if valid_token is None: if valid_token is None:
## check db ## check db
verbose_proxy_logger.debug(f"api key: {api_key}") verbose_proxy_logger.debug(f"api key: {api_key}")
@ -482,10 +486,10 @@ async def user_api_key_auth(
) )
# Token passed all checks # Token passed all checks
# Add token to cache
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
api_key = valid_token.token api_key = valid_token.token
# Add hashed token to cache
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
valid_token_dict = _get_pydantic_json_dict(valid_token) valid_token_dict = _get_pydantic_json_dict(valid_token)
valid_token_dict.pop("token", None) valid_token_dict.pop("token", None)
""" """
@ -520,7 +524,10 @@ async def user_api_key_auth(
# check if user can access this route # check if user can access this route
query_params = request.query_params query_params = request.query_params
key = query_params.get("key") key = query_params.get("key")
if prisma_client.hash_token(token=key) != api_key: if (
key is not None
and prisma_client.hash_token(token=key) != api_key
):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="user not allowed to access this key's info", detail="user not allowed to access this key's info",
@ -748,6 +755,9 @@ async def update_database(
### UPDATE KEY SPEND ### ### UPDATE KEY SPEND ###
async def _update_key_db(): async def _update_key_db():
verbose_proxy_logger.debug(
f"adding spend to key db. Response cost: {response_cost}. Token: {token}."
)
if prisma_client is not None: if prisma_client is not None:
# Fetch the existing cost for the given token # Fetch the existing cost for the given token
existing_spend_obj = await prisma_client.get_data(token=token) existing_spend_obj = await prisma_client.get_data(token=token)
@ -1239,6 +1249,7 @@ async def generate_key_helper_fn(
rpm_limit: Optional[int] = None, rpm_limit: Optional[int] = None,
query_type: Literal["insert_data", "update_data"] = "insert_data", query_type: Literal["insert_data", "update_data"] = "insert_data",
update_key_values: Optional[dict] = None, update_key_values: Optional[dict] = None,
key_alias: Optional[str] = None,
): ):
global prisma_client, custom_db_client global prisma_client, custom_db_client
@ -1312,6 +1323,7 @@ async def generate_key_helper_fn(
} }
key_data = { key_data = {
"token": token, "token": token,
"key_alias": key_alias,
"expires": expires, "expires": expires,
"models": models, "models": models,
"aliases": aliases_json, "aliases": aliases_json,
@ -1327,6 +1339,8 @@ async def generate_key_helper_fn(
"budget_duration": key_budget_duration, "budget_duration": key_budget_duration,
"budget_reset_at": key_reset_at, "budget_reset_at": key_reset_at,
} }
if general_settings.get("allow_user_auth", False) == True:
key_data["key_name"] = f"sk-...{token[-4:]}"
if prisma_client is not None: if prisma_client is not None:
## CREATE USER (If necessary) ## CREATE USER (If necessary)
verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}") verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}")
@ -2451,10 +2465,10 @@ async def delete_key_fn(data: DeleteKeyRequest):
Delete a key from the key management system. Delete a key from the key management system.
Parameters:: Parameters::
- keys (List[str]): A list of keys to delete. Example {"keys": ["sk-QWrxEynunsNpV1zT48HIrw"]} - keys (List[str]): A list of keys or hashed keys to delete. Example {"keys": ["sk-QWrxEynunsNpV1zT48HIrw", "837e17519f44683334df5291321d97b8bf1098cd490e49e215f6fea935aa28be"]}
Returns: Returns:
- deleted_keys (List[str]): A list of deleted keys. Example {"deleted_keys": ["sk-QWrxEynunsNpV1zT48HIrw"]} - deleted_keys (List[str]): A list of deleted keys. Example {"deleted_keys": ["sk-QWrxEynunsNpV1zT48HIrw", "837e17519f44683334df5291321d97b8bf1098cd490e49e215f6fea935aa28be"]}
Raises: Raises:
@ -2491,14 +2505,39 @@ async def delete_key_fn(data: DeleteKeyRequest):
"/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)] "/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
) )
async def info_key_fn( async def info_key_fn(
key: str = fastapi.Query(..., description="Key in the request parameters"), key: Optional[str] = fastapi.Query(
default=None, description="Key in the request parameters"
),
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
): ):
"""
Retrieve information about a key.
Parameters:
key: Optional[str] = Query parameter representing the key in the request
user_api_key_dict: UserAPIKeyAuth = Dependency representing the user's API key
Returns:
Dict containing the key and its associated information
Example Curl:
```
curl -X GET "http://0.0.0.0:8000/key/info?key=sk-02Wr4IAlN3NvPXvL5JVvDA" \
-H "Authorization: Bearer sk-1234"
```
Example Curl - if no key is passed, it will use the Key Passed in Authorization Header
```
curl -X GET "http://0.0.0.0:8000/key/info" \
-H "Authorization: Bearer sk-02Wr4IAlN3NvPXvL5JVvDA"
```
"""
global prisma_client global prisma_client
try: try:
if prisma_client is None: if prisma_client is None:
raise Exception( raise Exception(
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
) )
if key == None:
key = user_api_key_dict.api_key
key_info = await prisma_client.get_data(token=key) key_info = await prisma_client.get_data(token=key)
## REMOVE HASHED TOKEN INFO BEFORE RETURNING ## ## REMOVE HASHED TOKEN INFO BEFORE RETURNING ##
try: try:


@@ -7,6 +7,7 @@ generator client {
   provider = "prisma-client-py"
 }

+// Track spend, rate limit, budget Users
 model LiteLLM_UserTable {
   user_id String @unique
   team_id String?
@@ -21,9 +22,11 @@ model LiteLLM_UserTable {
   budget_reset_at DateTime?
 }

-// required for token gen
+// Generate Tokens for Proxy
 model LiteLLM_VerificationToken {
   token String @unique
+  key_name String?
+  key_alias String?
   spend Float @default(0.0)
   expires DateTime?
   models String[]
@@ -40,22 +43,25 @@ model LiteLLM_VerificationToken {
   budget_reset_at DateTime?
 }

+// store proxy config.yaml
 model LiteLLM_Config {
   param_name String @id
   param_value Json?
 }

+// View spend, model, api_key per request
 model LiteLLM_SpendLogs {
   request_id String @unique
   call_type String
   api_key String @default ("")
   spend Float @default(0.0)
+  total_tokens Int @default(0)
+  prompt_tokens Int @default(0)
+  completion_tokens Int @default(0)
   startTime DateTime // Assuming start_time is a DateTime field
   endTime DateTime // Assuming end_time is a DateTime field
   model String @default("")
   user String @default("")
-  modelParameters Json @default("{}") // Assuming optional_params is a JSON field
-  usage Json @default("{}")
   metadata Json @default("{}")
   cache_hit String @default("")
   cache_key String @default("")


@ -198,7 +198,14 @@ class ProxyLogging:
max_budget = user_info["max_budget"] max_budget = user_info["max_budget"]
spend = user_info["spend"] spend = user_info["spend"]
user_email = user_info["user_email"] user_email = user_info["user_email"]
user_info = f"""\nUser ID: {user_id}\nMax Budget: {max_budget}\nSpend: {spend}\nUser Email: {user_email}""" user_info = f"""\nUser ID: {user_id}\nMax Budget: ${max_budget}\nSpend: ${spend}\nUser Email: {user_email}"""
elif type == "token_budget":
token_info = dict(user_info)
token = token_info["token"]
spend = token_info["spend"]
max_budget = token_info["max_budget"]
user_id = token_info["user_id"]
user_info = f"""\nToken: {token}\nSpend: ${spend}\nMax Budget: ${max_budget}\nUser ID: {user_id}"""
else: else:
user_info = str(user_info) user_info = str(user_info)
# percent of max_budget left to spend # percent of max_budget left to spend
@ -814,7 +821,13 @@ class PrismaClient:
Allow user to delete a key(s) Allow user to delete a key(s)
""" """
try: try:
hashed_tokens = [self.hash_token(token=token) for token in tokens] hashed_tokens = []
for token in tokens:
if isinstance(token, str) and token.startswith("sk-"):
hashed_token = self.hash_token(token=token)
else:
hashed_token = token
hashed_tokens.append(hashed_token)
await self.db.litellm_verificationtoken.delete_many( await self.db.litellm_verificationtoken.delete_many(
where={"token": {"in": hashed_tokens}} where={"token": {"in": hashed_tokens}}
) )
@ -1060,10 +1073,11 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
metadata = ( metadata = (
litellm_params.get("metadata", {}) or {} litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None ) # if litellm_params['metadata'] == None
optional_params = kwargs.get("optional_params", {})
call_type = kwargs.get("call_type", "litellm.completion") call_type = kwargs.get("call_type", "litellm.completion")
cache_hit = kwargs.get("cache_hit", False) cache_hit = kwargs.get("cache_hit", False)
usage = response_obj["usage"] usage = response_obj["usage"]
if type(usage) == litellm.Usage:
usage = dict(usage)
id = response_obj.get("id", str(uuid.uuid4())) id = response_obj.get("id", str(uuid.uuid4()))
api_key = metadata.get("user_api_key", "") api_key = metadata.get("user_api_key", "")
if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"): if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
@ -1091,10 +1105,11 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
"endTime": end_time, "endTime": end_time,
"model": kwargs.get("model", ""), "model": kwargs.get("model", ""),
"user": kwargs.get("user", ""), "user": kwargs.get("user", ""),
"modelParameters": optional_params,
"usage": usage,
"metadata": metadata, "metadata": metadata,
"cache_key": cache_key, "cache_key": cache_key,
"total_tokens": usage.get("total_tokens", 0),
"prompt_tokens": usage.get("prompt_tokens", 0),
"completion_tokens": usage.get("completion_tokens", 0),
} }
json_fields = [ json_fields = [
@ -1119,8 +1134,6 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
payload[param] = payload[param].model_dump_json() payload[param] = payload[param].model_dump_json()
if type(payload[param]) == litellm.EmbeddingResponse: if type(payload[param]) == litellm.EmbeddingResponse:
payload[param] = payload[param].model_dump_json() payload[param] = payload[param].model_dump_json()
elif type(payload[param]) == litellm.Usage:
payload[param] = payload[param].model_dump_json()
else: else:
payload[param] = json.dumps(payload[param]) payload[param] = json.dumps(payload[param])


@@ -723,8 +723,8 @@ def test_cache_override():
     print(f"Embedding 2 response time: {end_time - start_time} seconds")

     assert (
-        end_time - start_time > 0.1
-    )  # ensure 2nd response comes in over 0.1s. This should not be cached.
+        end_time - start_time > 0.05
+    )  # ensure 2nd response comes in over 0.05s. This should not be cached.


 # test_cache_override()


@ -124,7 +124,7 @@ def test_cost_azure_gpt_35():
) )
test_cost_azure_gpt_35() # test_cost_azure_gpt_35()
def test_cost_azure_embedding(): def test_cost_azure_embedding():
@ -165,3 +165,71 @@ def test_cost_openai_image_gen():
model="dall-e-2", size="1024-x-1024", quality="standard", n=1 model="dall-e-2", size="1024-x-1024", quality="standard", n=1
) )
assert cost == 0.019922944 assert cost == 0.019922944
def test_cost_bedrock_pricing():
"""
- get pricing specific to region for a model
"""
from litellm import ModelResponse, Choices, Message
from litellm.utils import Usage
litellm.set_verbose = True
input_tokens = litellm.token_counter(
model="bedrock/anthropic.claude-instant-v1",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(f"input_tokens: {input_tokens}")
output_tokens = litellm.token_counter(
model="bedrock/anthropic.claude-instant-v1",
text="It's all going well",
count_response_tokens=True,
)
print(f"output_tokens: {output_tokens}")
resp = ModelResponse(
id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac",
choices=[
Choices(
finish_reason=None,
index=0,
message=Message(
content="It's all going well",
role="assistant",
),
)
],
created=1700775391,
model="anthropic.claude-instant-v1",
object="chat.completion",
system_fingerprint=None,
usage=Usage(
prompt_tokens=input_tokens,
completion_tokens=output_tokens,
total_tokens=input_tokens + output_tokens,
),
)
resp._hidden_params = {
"custom_llm_provider": "bedrock",
"region_name": "ap-northeast-1",
}
cost = litellm.completion_cost(
model="anthropic.claude-instant-v1",
completion_response=resp,
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
predicted_cost = input_tokens * 0.00000223 + 0.00000755 * output_tokens
assert cost == predicted_cost
def test_cost_bedrock_pricing_actual_calls():
litellm.set_verbose = True
model = "anthropic.claude-instant-v1"
messages = [{"role": "user", "content": "Hey, how's it going?"}]
response = litellm.completion(model=model, messages=messages)
assert response._hidden_params["region_name"] is not None
cost = litellm.completion_cost(
completion_response=response,
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
assert cost > 0


@ -53,9 +53,9 @@ model_list:
api_key: os.environ/AZURE_API_KEY api_key: os.environ/AZURE_API_KEY
api_version: 2023-07-01-preview api_version: 2023-07-01-preview
model: azure/azure-embedding-model model: azure/azure-embedding-model
model_name: azure-embedding-model
model_info: model_info:
mode: "embedding" mode: embedding
model_name: azure-embedding-model
- litellm_params: - litellm_params:
model: gpt-3.5-turbo model: gpt-3.5-turbo
model_info: model_info:
@ -80,43 +80,49 @@ model_list:
description: this is a test openai model description: this is a test openai model
id: 9b1ef341-322c-410a-8992-903987fef439 id: 9b1ef341-322c-410a-8992-903987fef439
model_name: test_openai_models model_name: test_openai_models
- model_name: amazon-embeddings - litellm_params:
litellm_params: model: bedrock/amazon.titan-embed-text-v1
model: "bedrock/amazon.titan-embed-text-v1"
model_info: model_info:
mode: embedding mode: embedding
- model_name: "GPT-J 6B - Sagemaker Text Embedding (Internal)" model_name: amazon-embeddings
litellm_params: - litellm_params:
model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16" model: sagemaker/berri-benchmarking-gpt-j-6b-fp16
model_info: model_info:
mode: embedding mode: embedding
- model_name: dall-e-3 model_name: GPT-J 6B - Sagemaker Text Embedding (Internal)
litellm_params: - litellm_params:
model: dall-e-3 model: dall-e-3
model_info: model_info:
mode: image_generation mode: image_generation
- model_name: dall-e-3 model_name: dall-e-3
litellm_params: - litellm_params:
model: "azure/dall-e-3-test" api_base: os.environ/AZURE_SWEDEN_API_BASE
api_version: "2023-12-01-preview" api_key: os.environ/AZURE_SWEDEN_API_KEY
api_base: "os.environ/AZURE_SWEDEN_API_BASE" api_version: 2023-12-01-preview
api_key: "os.environ/AZURE_SWEDEN_API_KEY" model: azure/dall-e-3-test
model_info: model_info:
mode: image_generation mode: image_generation
- model_name: dall-e-2 model_name: dall-e-3
litellm_params: - litellm_params:
model: "azure/" api_base: os.environ/AZURE_API_BASE
api_version: "2023-06-01-preview" api_key: os.environ/AZURE_API_KEY
api_base: "os.environ/AZURE_API_BASE" api_version: 2023-06-01-preview
api_key: "os.environ/AZURE_API_KEY" model: azure/
model_info: model_info:
mode: image_generation mode: image_generation
- model_name: text-embedding-ada-002 model_name: dall-e-2
litellm_params: - litellm_params:
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: 2023-07-01-preview
model: azure/azure-embedding-model model: azure/azure-embedding-model
api_base: "os.environ/AZURE_API_BASE"
api_key: "os.environ/AZURE_API_KEY"
api_version: "2023-07-01-preview"
model_info: model_info:
base_model: text-embedding-ada-002
mode: embedding mode: embedding
base_model: text-embedding-ada-002 model_name: text-embedding-ada-002
- litellm_params:
model: gpt-3.5-turbo
model_info:
description: this is a test openai model
id: 34cb2419-7c63-44ae-a189-53f1d1ce5953
model_name: test_openai_models


@ -819,47 +819,49 @@ async def test_async_embedding_azure_caching():
# Image Generation # Image Generation
# ## Test OpenAI + Sync ## Test OpenAI + Sync
# def test_image_generation_openai(): def test_image_generation_openai():
# try: try:
# customHandler_success = CompletionCustomHandler() customHandler_success = CompletionCustomHandler()
# customHandler_failure = CompletionCustomHandler() customHandler_failure = CompletionCustomHandler()
# litellm.callbacks = [customHandler_success] # litellm.callbacks = [customHandler_success]
# litellm.set_verbose = True # litellm.set_verbose = True
# response = litellm.image_generation( # response = litellm.image_generation(
# prompt="A cute baby sea otter", model="dall-e-3" # prompt="A cute baby sea otter", model="dall-e-3"
# ) # )
# print(f"response: {response}") # print(f"response: {response}")
# assert len(response.data) > 0 # assert len(response.data) > 0
# print(f"customHandler_success.errors: {customHandler_success.errors}") # print(f"customHandler_success.errors: {customHandler_success.errors}")
# print(f"customHandler_success.states: {customHandler_success.states}") # print(f"customHandler_success.states: {customHandler_success.states}")
# assert len(customHandler_success.errors) == 0 # assert len(customHandler_success.errors) == 0
# assert len(customHandler_success.states) == 3 # pre, post, success # assert len(customHandler_success.states) == 3 # pre, post, success
# # test failure callback # test failure callback
# litellm.callbacks = [customHandler_failure] litellm.callbacks = [customHandler_failure]
# try: try:
# response = litellm.image_generation( response = litellm.image_generation(
# prompt="A cute baby sea otter", model="dall-e-4" prompt="A cute baby sea otter",
# ) model="dall-e-2",
# except: api_key="my-bad-api-key",
# pass )
# print(f"customHandler_failure.errors: {customHandler_failure.errors}") except:
# print(f"customHandler_failure.states: {customHandler_failure.states}") pass
# assert len(customHandler_failure.errors) == 0 print(f"customHandler_failure.errors: {customHandler_failure.errors}")
# assert len(customHandler_failure.states) == 3 # pre, post, failure print(f"customHandler_failure.states: {customHandler_failure.states}")
# except litellm.RateLimitError as e: assert len(customHandler_failure.errors) == 0
# pass assert len(customHandler_failure.states) == 3 # pre, post, failure
# except litellm.ContentPolicyViolationError: except litellm.RateLimitError as e:
# pass # OpenAI randomly raises these errors - skip when they occur pass
# except Exception as e: except litellm.ContentPolicyViolationError:
# pytest.fail(f"An exception occurred - {str(e)}") pass # OpenAI randomly raises these errors - skip when they occur
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
# test_image_generation_openai() test_image_generation_openai()
## Test OpenAI + Async ## Test OpenAI + Async
## Test Azure + Sync ## Test Azure + Sync


@@ -64,7 +64,9 @@ def test_openai_embedding_3():
             model="text-embedding-3-small",
             input=["good morning from litellm", "this is another item"],
             metadata={"anything": "good day"},
+            dimensions=5,
         )
+        print(f"response:", response)
         litellm_response = dict(response)
         litellm_response_keys = set(litellm_response.keys())
         litellm_response_keys.discard("_response_ms")
@@ -80,6 +82,7 @@ def test_openai_embedding_3():
         response = client.embeddings.create(
             model="text-embedding-3-small",
             input=["good morning from litellm", "this is another item"],
+            dimensions=5,
         )

         response = dict(response)


@ -33,7 +33,7 @@ from litellm.proxy.proxy_server import (
) )
from litellm.proxy._types import NewUserRequest, DynamoDBArgs, GenerateKeyRequest from litellm.proxy._types import NewUserRequest, DynamoDBArgs, GenerateKeyRequest
from litellm.proxy.utils import DBClient from litellm.proxy.utils import DBClient, hash_token
from starlette.datastructures import URL from starlette.datastructures import URL
@ -232,7 +232,7 @@ def test_call_with_user_over_budget(custom_db_client):
"stream": False, "stream": False,
"litellm_params": { "litellm_params": {
"metadata": { "metadata": {
"user_api_key": generated_key, "user_api_key": hash_token(generated_key),
"user_api_key_user_id": user_id, "user_api_key_user_id": user_id,
} }
}, },
@ -305,7 +305,7 @@ def test_call_with_user_over_budget_stream(custom_db_client):
"complete_streaming_response": resp, "complete_streaming_response": resp,
"litellm_params": { "litellm_params": {
"metadata": { "metadata": {
"user_api_key": generated_key, "user_api_key": hash_token(generated_key),
"user_api_key_user_id": user_id, "user_api_key_user_id": user_id,
} }
}, },
@ -376,7 +376,7 @@ def test_call_with_user_key_budget(custom_db_client):
"stream": False, "stream": False,
"litellm_params": { "litellm_params": {
"metadata": { "metadata": {
"user_api_key": generated_key, "user_api_key": hash_token(generated_key),
"user_api_key_user_id": user_id, "user_api_key_user_id": user_id,
} }
}, },
@ -449,7 +449,7 @@ def test_call_with_key_over_budget_stream(custom_db_client):
"complete_streaming_response": resp, "complete_streaming_response": resp,
"litellm_params": { "litellm_params": {
"metadata": { "metadata": {
"user_api_key": generated_key, "user_api_key": hash_token(generated_key),
"user_api_key_user_id": user_id, "user_api_key_user_id": user_id,
} }
}, },


@ -12,6 +12,8 @@
# 11. Generate a Key, cal key/info, call key/update, call key/info # 11. Generate a Key, cal key/info, call key/update, call key/info
# 12. Make a call with key over budget, expect to fail # 12. Make a call with key over budget, expect to fail
# 14. Make a streaming chat/completions call with key over budget, expect to fail # 14. Make a streaming chat/completions call with key over budget, expect to fail
# 15. Generate key, when `allow_user_auth`=False - check if `/key/info` returns key_name=null
# 16. Generate key, when `allow_user_auth`=True - check if `/key/info` returns key_name=sk...<last-4-digits>
# function to call to generate key - async def new_user(data: NewUserRequest): # function to call to generate key - async def new_user(data: NewUserRequest):
@ -46,7 +48,7 @@ from litellm.proxy.proxy_server import (
spend_key_fn, spend_key_fn,
view_spend_logs, view_spend_logs,
) )
from litellm.proxy.utils import PrismaClient, ProxyLogging from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
verbose_proxy_logger.setLevel(level=logging.DEBUG) verbose_proxy_logger.setLevel(level=logging.DEBUG)
@ -86,6 +88,7 @@ def prisma_client():
litellm.proxy.proxy_server.litellm_proxy_budget_name = ( litellm.proxy.proxy_server.litellm_proxy_budget_name = (
f"litellm-proxy-budget-{time.time()}" f"litellm-proxy-budget-{time.time()}"
) )
litellm.proxy.proxy_server.user_custom_key_generate = None
return prisma_client return prisma_client
@ -918,7 +921,7 @@ def test_call_with_key_over_budget(prisma_client):
"stream": False, "stream": False,
"litellm_params": { "litellm_params": {
"metadata": { "metadata": {
"user_api_key": generated_key, "user_api_key": hash_token(generated_key),
"user_api_key_user_id": user_id, "user_api_key_user_id": user_id,
} }
}, },
@ -1009,7 +1012,7 @@ async def test_call_with_key_never_over_budget(prisma_client):
"stream": False, "stream": False,
"litellm_params": { "litellm_params": {
"metadata": { "metadata": {
"user_api_key": generated_key, "user_api_key": hash_token(generated_key),
"user_api_key_user_id": user_id, "user_api_key_user_id": user_id,
} }
}, },
@ -1083,7 +1086,7 @@ async def test_call_with_key_over_budget_stream(prisma_client):
"complete_streaming_response": resp, "complete_streaming_response": resp,
"litellm_params": { "litellm_params": {
"metadata": { "metadata": {
"user_api_key": generated_key, "user_api_key": hash_token(generated_key),
"user_api_key_user_id": user_id, "user_api_key_user_id": user_id,
} }
}, },
@ -1140,3 +1143,48 @@ async def test_view_spend_per_key(prisma_client):
except Exception as e: except Exception as e:
print("Got Exception", e) print("Got Exception", e)
pytest.fail(f"Got exception {e}") pytest.fail(f"Got exception {e}")
@pytest.mark.asyncio()
async def test_key_name_null(prisma_client):
"""
- create key
- get key info
- assert key_name is null
"""
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
try:
request = GenerateKeyRequest()
key = await generate_key_fn(request)
generated_key = key.key
result = await info_key_fn(key=generated_key)
print("result from info_key_fn", result)
assert result["info"]["key_name"] is None
except Exception as e:
print("Got Exception", e)
pytest.fail(f"Got exception {e}")
@pytest.mark.asyncio()
async def test_key_name_set(prisma_client):
"""
- create key
- get key info
- assert key_name is not null
"""
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True})
await litellm.proxy.proxy_server.prisma_client.connect()
try:
request = GenerateKeyRequest()
key = await generate_key_fn(request)
generated_key = key.key
result = await info_key_fn(key=generated_key)
print("result from info_key_fn", result)
assert isinstance(result["info"]["key_name"], str)
except Exception as e:
print("Got Exception", e)
pytest.fail(f"Got exception {e}")


@@ -32,7 +32,7 @@ from litellm.proxy.proxy_server import (
 )  # Replace with the actual module where your FastAPI router is defined

 # Your bearer token
-token = ""
+token = "sk-1234"

 headers = {"Authorization": f"Bearer {token}"}


@@ -31,7 +31,7 @@ from litellm.proxy.proxy_server import (
 )  # Replace with the actual module where your FastAPI router is defined

 # Your bearer token
-token = ""
+token = "sk-1234"

 headers = {"Authorization": f"Bearer {token}"}


@@ -33,7 +33,7 @@ from litellm.proxy.proxy_server import (
 )  # Replace with the actual module where your FastAPI router is defined

 # Your bearer token
-token = ""
+token = "sk-1234"

 headers = {"Authorization": f"Bearer {token}"}


@ -714,6 +714,7 @@ class ImageResponse(OpenAIObject):
############################################################ ############################################################
def print_verbose(print_statement): def print_verbose(print_statement):
try: try:
verbose_logger.debug(print_statement)
if litellm.set_verbose: if litellm.set_verbose:
print(print_statement) # noqa print(print_statement) # noqa
except: except:
@ -2029,14 +2030,15 @@ def client(original_function):
start_time=start_time, start_time=start_time,
) )
## check if metadata is passed in ## check if metadata is passed in
litellm_params = {}
if "metadata" in kwargs: if "metadata" in kwargs:
litellm_params = {"metadata": kwargs["metadata"]} litellm_params["metadata"] = kwargs["metadata"]
logging_obj.update_environment_variables( logging_obj.update_environment_variables(
model=model, model=model,
user="", user="",
optional_params={}, optional_params={},
litellm_params=litellm_params, litellm_params=litellm_params,
) )
return logging_obj return logging_obj
except Exception as e: except Exception as e:
import logging import logging
@ -2900,6 +2902,7 @@ def cost_per_token(
completion_tokens=0, completion_tokens=0,
response_time_ms=None, response_time_ms=None,
custom_llm_provider=None, custom_llm_provider=None,
region_name=None,
): ):
""" """
Calculates the cost per token for a given model, prompt tokens, and completion tokens. Calculates the cost per token for a given model, prompt tokens, and completion tokens.
@ -2916,16 +2919,46 @@ def cost_per_token(
prompt_tokens_cost_usd_dollar = 0 prompt_tokens_cost_usd_dollar = 0
completion_tokens_cost_usd_dollar = 0 completion_tokens_cost_usd_dollar = 0
model_cost_ref = litellm.model_cost model_cost_ref = litellm.model_cost
model_with_provider = model
if custom_llm_provider is not None: if custom_llm_provider is not None:
model_with_provider = custom_llm_provider + "/" + model model_with_provider = custom_llm_provider + "/" + model
else: if region_name is not None:
model_with_provider = model model_with_provider_and_region = (
f"{custom_llm_provider}/{region_name}/{model}"
)
if (
model_with_provider_and_region in model_cost_ref
): # use region based pricing, if it's available
model_with_provider = model_with_provider_and_region
# see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models # see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
verbose_logger.debug(f"Looking up model={model} in model_cost_map") print_verbose(f"Looking up model={model} in model_cost_map")
if model_with_provider in model_cost_ref:
print_verbose(
f"Success: model={model_with_provider} in model_cost_map - {model_cost_ref[model_with_provider]}"
)
print_verbose(
f"applying cost={model_cost_ref[model_with_provider].get('input_cost_per_token', None)} for prompt_tokens={prompt_tokens}"
)
prompt_tokens_cost_usd_dollar = (
model_cost_ref[model_with_provider]["input_cost_per_token"] * prompt_tokens
)
print_verbose(
f"calculated prompt_tokens_cost_usd_dollar: {prompt_tokens_cost_usd_dollar}"
)
print_verbose(
f"applying cost={model_cost_ref[model_with_provider].get('output_cost_per_token', None)} for completion_tokens={completion_tokens}"
)
completion_tokens_cost_usd_dollar = (
model_cost_ref[model_with_provider]["output_cost_per_token"]
* completion_tokens
)
print_verbose(
f"calculated completion_tokens_cost_usd_dollar: {completion_tokens_cost_usd_dollar}"
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
if model in model_cost_ref: if model in model_cost_ref:
verbose_logger.debug(f"Success: model={model} in model_cost_map") print_verbose(f"Success: model={model} in model_cost_map")
verbose_logger.debug( print_verbose(
f"prompt_tokens={prompt_tokens}; completion_tokens={completion_tokens}" f"prompt_tokens={prompt_tokens}; completion_tokens={completion_tokens}"
) )
if ( if (
@ -2943,7 +2976,7 @@ def cost_per_token(
model_cost_ref[model].get("input_cost_per_second", None) is not None model_cost_ref[model].get("input_cost_per_second", None) is not None
and response_time_ms is not None and response_time_ms is not None
): ):
verbose_logger.debug( print_verbose(
f"For model={model} - input_cost_per_second: {model_cost_ref[model].get('input_cost_per_second')}; response time: {response_time_ms}" f"For model={model} - input_cost_per_second: {model_cost_ref[model].get('input_cost_per_second')}; response time: {response_time_ms}"
) )
## COST PER SECOND ## ## COST PER SECOND ##
@ -2951,30 +2984,12 @@ def cost_per_token(
model_cost_ref[model]["input_cost_per_second"] * response_time_ms / 1000 model_cost_ref[model]["input_cost_per_second"] * response_time_ms / 1000
) )
completion_tokens_cost_usd_dollar = 0.0 completion_tokens_cost_usd_dollar = 0.0
verbose_logger.debug( print_verbose(
f"Returned custom cost for model={model} - prompt_tokens_cost_usd_dollar: {prompt_tokens_cost_usd_dollar}, completion_tokens_cost_usd_dollar: {completion_tokens_cost_usd_dollar}" f"Returned custom cost for model={model} - prompt_tokens_cost_usd_dollar: {prompt_tokens_cost_usd_dollar}, completion_tokens_cost_usd_dollar: {completion_tokens_cost_usd_dollar}"
) )
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif model_with_provider in model_cost_ref:
verbose_logger.debug(
f"Looking up model={model_with_provider} in model_cost_map"
)
verbose_logger.debug(
f"applying cost={model_cost_ref[model_with_provider]['input_cost_per_token']} for prompt_tokens={prompt_tokens}"
)
prompt_tokens_cost_usd_dollar = (
model_cost_ref[model_with_provider]["input_cost_per_token"] * prompt_tokens
)
verbose_logger.debug(
f"applying cost={model_cost_ref[model_with_provider]['output_cost_per_token']} for completion_tokens={completion_tokens}"
)
completion_tokens_cost_usd_dollar = (
model_cost_ref[model_with_provider]["output_cost_per_token"]
* completion_tokens
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif "ft:gpt-3.5-turbo" in model: elif "ft:gpt-3.5-turbo" in model:
verbose_logger.debug(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM") print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
# fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm # fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm
prompt_tokens_cost_usd_dollar = ( prompt_tokens_cost_usd_dollar = (
model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens
@ -3031,7 +3046,10 @@ def completion_cost(
prompt="", prompt="",
messages: List = [], messages: List = [],
completion="", completion="",
total_time=0.0, # used for replicate total_time=0.0, # used for replicate, sagemaker
### REGION ###
custom_llm_provider=None,
region_name=None, # used for bedrock pricing
### IMAGE GEN ### ### IMAGE GEN ###
size=None, size=None,
quality=None, quality=None,
@ -3080,12 +3098,13 @@ def completion_cost(
model = ( model = (
model or completion_response["model"] model or completion_response["model"]
) # check if user passed an override for model, if it's none check completion_response['model'] ) # check if user passed an override for model, if it's none check completion_response['model']
if completion_response is not None and hasattr( if hasattr(completion_response, "_hidden_params"):
completion_response, "_hidden_params"
):
custom_llm_provider = completion_response._hidden_params.get( custom_llm_provider = completion_response._hidden_params.get(
"custom_llm_provider", "" "custom_llm_provider", ""
) )
region_name = completion_response._hidden_params.get(
"region_name", region_name
)
else: else:
if len(messages) > 0: if len(messages) > 0:
prompt_tokens = token_counter(model=model, messages=messages) prompt_tokens = token_counter(model=model, messages=messages)
@ -3146,8 +3165,13 @@ def completion_cost(
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
custom_llm_provider=custom_llm_provider, custom_llm_provider=custom_llm_provider,
response_time_ms=total_time, response_time_ms=total_time,
region_name=region_name,
) )
return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar _final_cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
print_verbose(
f"final cost: {_final_cost}; prompt_tokens_cost_usd_dollar: {prompt_tokens_cost_usd_dollar}; completion_tokens_cost_usd_dollar: {completion_tokens_cost_usd_dollar}"
)
return _final_cost
except Exception as e: except Exception as e:
raise e raise e
@ -3313,8 +3337,10 @@ def get_optional_params_image_gen(
def get_optional_params_embeddings( def get_optional_params_embeddings(
# 2 optional params # 2 optional params
model=None,
user=None, user=None,
encoding_format=None, encoding_format=None,
dimensions=None,
custom_llm_provider="", custom_llm_provider="",
**kwargs, **kwargs,
): ):
@ -3325,7 +3351,7 @@ def get_optional_params_embeddings(
for k, v in special_params.items(): for k, v in special_params.items():
passed_params[k] = v passed_params[k] = v
default_params = {"user": None, "encoding_format": None} default_params = {"user": None, "encoding_format": None, "dimensions": None}
non_default_params = { non_default_params = {
k: v k: v
@ -3333,6 +3359,19 @@ def get_optional_params_embeddings(
if (k in default_params and v != default_params[k]) if (k in default_params and v != default_params[k])
} }
## raise exception if non-default value passed for non-openai/azure embedding calls ## raise exception if non-default value passed for non-openai/azure embedding calls
if custom_llm_provider == "openai":
# 'dimensions` is only supported in `text-embedding-3` and later models
if (
model is not None
and "text-embedding-3" not in model
and "dimensions" in non_default_params.keys()
):
raise UnsupportedParamsError(
status_code=500,
message=f"Setting dimensions is not supported for OpenAI `text-embedding-3` and later models. To drop it from the call, set `litellm.drop_params = True`.",
)
if ( if (
custom_llm_provider != "openai" custom_llm_provider != "openai"
and custom_llm_provider != "azure" and custom_llm_provider != "azure"

poetry.lock generated

@@ -1158,13 +1158,13 @@ files = [

 [[package]]
 name = "openai"
-version = "1.8.0"
+version = "1.10.0"
 description = "The official Python library for the openai API"
 optional = false
 python-versions = ">=3.7.1"
 files = [
-    {file = "openai-1.8.0-py3-none-any.whl", hash = "sha256:0f8f53805826103fdd8adaf379ad3ec23f9d867e698cbc14caf34b778d150175"},
-    {file = "openai-1.8.0.tar.gz", hash = "sha256:93366be27802f517e89328801913d2a5ede45e3b86fdcab420385b8a1b88c767"},
+    {file = "openai-1.10.0-py3-none-any.whl", hash = "sha256:aa69e97d0223ace9835fbf9c997abe9ee95318f684fd2de6d02c870700c71ebc"},
+    {file = "openai-1.10.0.tar.gz", hash = "sha256:208886cb501b930dc63f48d51db9c15e5380380f80516d07332adad67c9f1053"},
 ]

 [package.dependencies]


@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.19.4"
+version = "1.20.0"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -63,7 +63,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"

 [tool.commitizen]
-version = "1.19.4"
+version = "1.20.0"
 version_files = [
     "pyproject.toml:^version"
 ]


@@ -25,6 +25,8 @@ model LiteLLM_UserTable {
 // Generate Tokens for Proxy
 model LiteLLM_VerificationToken {
   token String @unique
+  key_name String?
+  key_alias String?
   spend Float @default(0.0)
   expires DateTime?
   models String[]
@@ -53,12 +55,13 @@ model LiteLLM_SpendLogs {
   call_type String
   api_key String @default ("")
   spend Float @default(0.0)
+  total_tokens Int @default(0)
+  prompt_tokens Int @default(0)
+  completion_tokens Int @default(0)
   startTime DateTime // Assuming start_time is a DateTime field
   endTime DateTime // Assuming end_time is a DateTime field
   model String @default("")
   user String @default("")
-  modelParameters Json @default("{}") // Assuming optional_params is a JSON field
-  usage Json @default("{}")
   metadata Json @default("{}")
   cache_hit String @default("")
   cache_key String @default("")


@ -115,7 +115,9 @@ async def chat_completion(session, key, model="gpt-4"):
print() print()
if status != 200: if status != 200:
raise Exception(f"Request did not return a 200 status code: {status}") raise Exception(
f"Request did not return a 200 status code: {status}. Response: {response_text}"
)
return await response.json() return await response.json()
@ -201,11 +203,14 @@ async def test_key_delete():
) )
async def get_key_info(session, get_key, call_key): async def get_key_info(session, call_key, get_key=None):
""" """
Make sure only models user has access to are returned Make sure only models user has access to are returned
""" """
url = f"http://0.0.0.0:4000/key/info?key={get_key}" if get_key is None:
url = "http://0.0.0.0:4000/key/info"
else:
url = f"http://0.0.0.0:4000/key/info?key={get_key}"
headers = { headers = {
"Authorization": f"Bearer {call_key}", "Authorization": f"Bearer {call_key}",
"Content-Type": "application/json", "Content-Type": "application/json",
@ -241,6 +246,9 @@ async def test_key_info():
await get_key_info(session=session, get_key=key, call_key="sk-1234") await get_key_info(session=session, get_key=key, call_key="sk-1234")
# as key itself # # as key itself #
await get_key_info(session=session, get_key=key, call_key=key) await get_key_info(session=session, get_key=key, call_key=key)
# as key itself, use the auth param, and no query key needed
await get_key_info(session=session, call_key=key)
# as random key # # as random key #
key_gen = await generate_key(session=session, i=0) key_gen = await generate_key(session=session, i=0)
random_key = key_gen["key"] random_key = key_gen["key"]
@ -281,14 +289,20 @@ async def test_key_info_spend_values():
await asyncio.sleep(5) await asyncio.sleep(5)
spend_logs = await get_spend_logs(session=session, request_id=response["id"]) spend_logs = await get_spend_logs(session=session, request_id=response["id"])
print(f"spend_logs: {spend_logs}") print(f"spend_logs: {spend_logs}")
usage = spend_logs[0]["usage"] completion_tokens = spend_logs[0]["completion_tokens"]
prompt_tokens = spend_logs[0]["prompt_tokens"]
print(f"prompt_tokens: {prompt_tokens}; completion_tokens: {completion_tokens}")
litellm.set_verbose = True
prompt_cost, completion_cost = litellm.cost_per_token( prompt_cost, completion_cost = litellm.cost_per_token(
model="gpt-35-turbo", model="gpt-35-turbo",
prompt_tokens=usage["prompt_tokens"], prompt_tokens=prompt_tokens,
completion_tokens=usage["completion_tokens"], completion_tokens=completion_tokens,
custom_llm_provider="azure", custom_llm_provider="azure",
) )
print("prompt_cost: ", prompt_cost, "completion_cost: ", completion_cost)
response_cost = prompt_cost + completion_cost response_cost = prompt_cost + completion_cost
print(f"response_cost: {response_cost}")
await asyncio.sleep(5) # allow db log to be updated await asyncio.sleep(5) # allow db log to be updated
key_info = await get_key_info(session=session, get_key=key, call_key=key) key_info = await get_key_info(session=session, get_key=key, call_key=key)
print( print(
@ -380,3 +394,31 @@ async def test_key_with_budgets():
key_info = await get_key_info(session=session, get_key=key, call_key=key) key_info = await get_key_info(session=session, get_key=key, call_key=key)
reset_at_new_value = key_info["info"]["budget_reset_at"] reset_at_new_value = key_info["info"]["budget_reset_at"]
assert reset_at_init_value != reset_at_new_value assert reset_at_init_value != reset_at_new_value
@pytest.mark.asyncio
async def test_key_crossing_budget():
"""
- Create key with budget with budget=0.00000001
- make a /chat/completions call
- wait 5s
- make a /chat/completions call - should fail with key crossed it's budget
- Check if value updated
"""
from litellm.proxy.utils import hash_token
async with aiohttp.ClientSession() as session:
key_gen = await generate_key(session=session, i=0, budget=0.0000001)
key = key_gen["key"]
hashed_token = hash_token(token=key)
print(f"hashed_token: {hashed_token}")
response = await chat_completion(session=session, key=key)
print("response 1: ", response)
await asyncio.sleep(2)
try:
response = await chat_completion(session=session, key=key)
pytest.fail("Should have failed - Key crossed it's budget")
except Exception as e:
assert "ExceededTokenBudget: Current spend for token:" in str(e)