commit 4dd18b553a
Merge remote-tracking branch 'src/main'
29 changed files with 550 additions and 170 deletions
.github/workflows/ghcr_deploy.yml (vendored, 7 changes)

@@ -34,13 +34,6 @@ jobs:
       with:
         push: true
         tags: litellm/litellm:${{ github.event.inputs.tag || 'latest' }}
-    -
-      name: Build and push litellm-ui image
-      uses: docker/build-push-action@v5
-      with:
-        push: true
-        file: ui/Dockerfile
-        tags: litellm/litellm-ui:${{ github.event.inputs.tag || 'latest' }}
     -
       name: Build and push litellm-database image
       uses: docker/build-push-action@v5
@@ -13,8 +13,8 @@ response = embedding(model='text-embedding-ada-002', input=["good morning from l
 
 - `model`: *string* - ID of the model to use. `model='text-embedding-ada-002'`
 
-- `input`: *array* - Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for text-embedding-ada-002), cannot be an empty string, and any array must be 2048 dimensions or less.
-```
+- `input`: *string or array* - Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for text-embedding-ada-002), cannot be an empty string, and any array must be 2048 dimensions or less.
+```python
 input=["good morning from litellm"]
 ```
 
@@ -22,7 +22,11 @@ input=["good morning from litellm"]
 
 - `user`: *string (optional)* A unique identifier representing your end-user,
 
-- `timeout`: *integer* - The maximum time, in seconds, to wait for the API to respond. Defaults to 600 seconds (10 minutes).
+- `dimensions`: *integer (Optional)* The number of dimensions the resulting output embeddings should have. Only supported in OpenAI/Azure text-embedding-3 and later models.
+
+- `encoding_format`: *string (Optional)* The format to return the embeddings in. Can be either `"float"` or `"base64"`. Defaults to `encoding_format="float"`
+
+- `timeout`: *integer (Optional)* - The maximum time, in seconds, to wait for the API to respond. Defaults to 600 seconds (10 minutes).
 
 - `api_base`: *string (optional)* - The api endpoint you want to call the model with
 
@@ -66,7 +70,12 @@ input=["good morning from litellm"]
 from litellm import embedding
 import os
 os.environ['OPENAI_API_KEY'] = ""
-response = embedding('text-embedding-ada-002', input=["good morning from litellm"])
+response = embedding(
+    model="text-embedding-3-small",
+    input=["good morning from litellm", "this is another item"],
+    metadata={"anything": "good day"},
+    dimensions=5 # Only supported in text-embedding-3 and later models.
+)
 ```
 
 | Model Name | Function Call | Required OS Variables |
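Taken together, the new parameters let callers shrink and re-encode embeddings in a single request. A minimal sketch combining `dimensions` and `encoding_format` (assumes a valid `OPENAI_API_KEY`; the placeholder key and 256-dimension choice are illustrative):

```python
import os
from litellm import embedding

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder; use a real key

# Request truncated, base64-encoded vectors from a text-embedding-3 model.
response = embedding(
    model="text-embedding-3-small",
    input=["good morning from litellm"],
    dimensions=256,            # only honored by text-embedding-3 and later models
    encoding_format="base64",  # "float" (default) or "base64"
)
print(response.usage)
```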
@@ -1,6 +1,13 @@
 # Slack Alerting
 
-Get alerts for failed db read/writes, hanging api calls, failed api calls.
+Get alerts for:
+- hanging LLM api calls
+- failed LLM api calls
+- slow LLM api calls
+- budget Tracking per key/user:
+  - When a User/Key crosses their Budget
+  - When a User/Key is 15% away from crossing their Budget
+- failed db read/writes
 
 ## Quick Start
 
@@ -605,6 +605,49 @@ response = router.completion(model="gpt-3.5-turbo", messages=messages)
 print(f"response: {response}")
 ```
 
+## Custom Callbacks - Track API Key, API Endpoint, Model Used
+
+If you need to track the api_key, api endpoint, model, custom_llm_provider used for each completion call, you can setup a [custom callback](https://docs.litellm.ai/docs/observability/custom_callback)
+
+### Usage
+
+```python
+import litellm
+from litellm.integrations.custom_logger import CustomLogger
+
+class MyCustomHandler(CustomLogger):
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Success")
+        print("kwargs=", kwargs)
+        litellm_params = kwargs.get("litellm_params")
+        api_key = litellm_params.get("api_key")
+        api_base = litellm_params.get("api_base")
+        custom_llm_provider = litellm_params.get("custom_llm_provider")
+        response_cost = kwargs.get("response_cost")
+
+        # print the values
+        print("api_key=", api_key)
+        print("api_base=", api_base)
+        print("custom_llm_provider=", custom_llm_provider)
+        print("response_cost=", response_cost)
+
+    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Failure")
+        print("kwargs=")
+
+customHandler = MyCustomHandler()
+
+litellm.callbacks = [customHandler]
+
+# Init Router
+router = Router(model_list=model_list, routing_strategy="simple-shuffle")
+
+# router completion call
+response = router.completion(
+    model="gpt-3.5-turbo",
+    messages=[{ "role": "user", "content": "Hi who are you"}]
+)
+```
+
 ## Deploy Router
 
@@ -1,3 +1,12 @@
+# +-----------------------------------------------+
+# |                                               |
+# |           NOT PROXY BUDGET MANAGER            |
+# |   proxy budget manager is in proxy_server.py  |
+# |                                               |
+# +-----------------------------------------------+
+#
+#  Thank you users! We ❤️ you! - Krrish & Ishaan
+
 import os, json, time
 import litellm
 from litellm.utils import ModelResponse
@@ -16,7 +25,7 @@ class BudgetManager:
         self.client_type = client_type
         self.project_name = project_name
         self.api_base = api_base or "https://api.litellm.ai"
-        self.headers = headers or {'Content-Type': 'application/json'}
+        self.headers = headers or {"Content-Type": "application/json"}
         ## load the data or init the initial dictionaries
         self.load_data()
 
@@ -659,9 +659,16 @@ def completion(
             )
 
         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
-        prompt_tokens = len(encoding.encode(prompt))
-        completion_tokens = len(
-            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        prompt_tokens = response_metadata.get(
+            "x-amzn-bedrock-input-token-count", len(encoding.encode(prompt))
+        )
+        completion_tokens = response_metadata.get(
+            "x-amzn-bedrock-output-token-count",
+            len(
+                encoding.encode(
+                    model_response["choices"][0]["message"].get("content", "")
+                )
+            ),
         )
 
         model_response["created"] = int(time.time())
@@ -672,6 +679,8 @@ def completion(
             total_tokens=prompt_tokens + completion_tokens,
         )
         model_response.usage = usage
+        model_response._hidden_params["region_name"] = client.meta.region_name
+        print_verbose(f"model_response._hidden_params: {model_response._hidden_params}")
        return model_response
    except BedrockError as e:
        exception_mapping_worked = True
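The bedrock.py change above prefers the token counts Bedrock reports in its `x-amzn-bedrock-*-token-count` response headers and only falls back to local encoding when a header is missing. A standalone sketch of that fallback pattern: `response_metadata` and `count_tokens` are stand-ins for the header dict and the `encoding.encode`-based counting bedrock.py already has in scope.

```python
# Sketch only: response_metadata stands in for the Bedrock HTTP response
# headers; count_tokens stands in for len(encoding.encode(...)).
def usage_from_metadata(response_metadata, prompt, completion_text, count_tokens):
    prompt_tokens = response_metadata.get(
        "x-amzn-bedrock-input-token-count", count_tokens(prompt)
    )
    completion_tokens = response_metadata.get(
        "x-amzn-bedrock-output-token-count", count_tokens(completion_text)
    )
    return prompt_tokens, completion_tokens, prompt_tokens + completion_tokens

# Header values win when present; the fallback counter is only used otherwise.
print(usage_from_metadata(
    {"x-amzn-bedrock-input-token-count": 12},
    "Hey, how's it going?",
    "It's all going well",
    count_tokens=lambda s: len(s.split()),
))  # -> (12, 4, 16)
```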
@@ -718,8 +718,22 @@ class OpenAIChatCompletion(BaseLLM):
             return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore
         except OpenAIError as e:
             exception_mapping_worked = True
+            ## LOGGING
+            logging_obj.post_call(
+                input=prompt,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=str(e),
+            )
             raise e
         except Exception as e:
+            ## LOGGING
+            logging_obj.post_call(
+                input=prompt,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=str(e),
+            )
             if hasattr(e, "status_code"):
                 raise OpenAIError(status_code=e.status_code, message=str(e))
             else:
@@ -10,7 +10,6 @@
 import os, openai, sys, json, inspect, uuid, datetime, threading
 from typing import Any, Literal, Union
 from functools import partial
-
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
 import httpx
@@ -586,6 +585,10 @@ def completion(
         )
     if model_response is not None and hasattr(model_response, "_hidden_params"):
         model_response._hidden_params["custom_llm_provider"] = custom_llm_provider
+        model_response._hidden_params["region_name"] = kwargs.get(
+            "aws_region_name", None
+        ) # support region-based pricing for bedrock
+
     ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
     if input_cost_per_token is not None and output_cost_per_token is not None:
         litellm.register_model(
@@ -2224,6 +2227,7 @@ def embedding(
     model,
     input=[],
     # Optional params
+    dimensions: Optional[int] = None,
     timeout=600, # default to 10 minutes
     # set api_base, api_version, api_key
     api_base: Optional[str] = None,
@@ -2244,6 +2248,7 @@ def embedding(
     Parameters:
     - model: The embedding model to use.
     - input: The input for which embeddings are to be generated.
+    - dimensions: The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
     - timeout: The timeout value for the API call, default 10 mins
     - litellm_call_id: The call ID for litellm logging.
     - litellm_logging_obj: The litellm logging object.
@@ -2277,6 +2282,7 @@ def embedding(
     output_cost_per_second = kwargs.get("output_cost_per_second", None)
     openai_params = [
         "user",
+        "dimensions",
         "request_timeout",
         "api_base",
         "api_version",
@@ -2345,7 +2351,9 @@ def embedding(
         api_key=api_key,
     )
     optional_params = get_optional_params_embeddings(
+        model=model,
         user=user,
+        dimensions=dimensions,
         encoding_format=encoding_format,
         custom_llm_provider=custom_llm_provider,
         **non_default_params,
@@ -3067,7 +3075,7 @@ def image_generation(
         custom_llm_provider=custom_llm_provider,
         **non_default_params,
     )
-    logging = litellm_logging_obj
+    logging: Logging = litellm_logging_obj
     logging.update_environment_variables(
         model=model,
         user=user,
@@ -140,6 +140,7 @@ class GenerateRequestBase(LiteLLMBase):
 
 
 class GenerateKeyRequest(GenerateRequestBase):
+    key_alias: Optional[str] = None
     duration: Optional[str] = "1h"
     aliases: Optional[dict] = {}
     config: Optional[dict] = {}
@@ -304,6 +305,8 @@ class ConfigYAML(LiteLLMBase):
 
 class LiteLLM_VerificationToken(LiteLLMBase):
     token: str
+    key_name: Optional[str] = None
+    key_alias: Optional[str] = None
     spend: float = 0.0
     max_budget: Optional[float] = None
     expires: Union[str, None]
@@ -346,11 +349,12 @@ class LiteLLM_SpendLogs(LiteLLMBase):
     model: Optional[str] = ""
     call_type: str
     spend: Optional[float] = 0.0
+    total_tokens: Optional[int] = 0
+    prompt_tokens: Optional[int] = 0
+    completion_tokens: Optional[int] = 0
     startTime: Union[str, datetime, None]
     endTime: Union[str, datetime, None]
     user: Optional[str] = ""
-    modelParameters: Optional[Json] = {}
-    usage: Optional[Json] = {}
     metadata: Optional[Json] = {}
     cache_hit: Optional[str] = "False"
     cache_key: Optional[str] = None
@@ -5,6 +5,7 @@ from litellm.proxy._types import (
     LiteLLM_Config,
     LiteLLM_UserTable,
 )
+from litellm.proxy.utils import hash_token
 from litellm import get_secret
 from typing import Any, List, Literal, Optional, Union
 import json
@@ -187,6 +188,8 @@ class DynamoDBWrapper(CustomDB):
             table = client.table(self.database_arguments.spend_table_name)
 
             for k, v in value.items():
+                if k == "token" and value[k].startswith("sk-"):
+                    value[k] = hash_token(token=v)
                 if isinstance(v, datetime):
                     value[k] = v.isoformat()
 
@@ -229,6 +232,10 @@ class DynamoDBWrapper(CustomDB):
             table = client.table(self.database_arguments.config_table_name)
             key_name = "param_name"
 
+            if key_name == "token" and key.startswith("sk-"):
+                # ensure it's hashed
+                key = hash_token(token=key)
+
             response = await table.get_item({key_name: key})
 
             new_response: Any = None
@@ -308,6 +315,8 @@ class DynamoDBWrapper(CustomDB):
                 # Convert datetime object to ISO8601 string
                 if isinstance(v, datetime):
                     v = v.isoformat()
+                if k == "token" and value[k].startswith("sk-"):
+                    value[k] = hash_token(token=v)
 
                 # Accumulate updates
                 actions.append((F(k), Value(value=v)))
@@ -11,6 +11,12 @@ model_list:
       output_cost_per_token: 0.00003
       max_tokens: 4096
       base_model: gpt-3.5-turbo
+  - model_name: gpt-4
+    litellm_params:
+      model: azure/chatgpt-v-2
+      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
+      api_version: "2023-05-15"
+      api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
   - model_name: gpt-vision
     litellm_params:
       model: azure/gpt-4-vision
@@ -61,7 +67,7 @@ model_list:
 litellm_settings:
   fallbacks: [{"openai-gpt-3.5": ["azure-gpt-3.5"]}]
   success_callback: ['langfuse']
-  max_budget: 0.025 # global budget for proxy
+  max_budget: 10 # global budget for proxy
   budget_duration: 30d # global budget duration, will reset after 30d
   # cache: True
   # setting callback class
@@ -75,6 +75,7 @@ from litellm.proxy.utils import (
     send_email,
     get_logging_payload,
     reset_budget,
+    hash_token,
 )
 from litellm.proxy.secret_managers.google_kms import load_google_kms
 import pydantic
@@ -243,6 +244,8 @@ async def user_api_key_auth(
             response = await user_custom_auth(request=request, api_key=api_key)
             return UserAPIKeyAuth.model_validate(response)
         ### LITELLM-DEFINED AUTH FUNCTION ###
+        if isinstance(api_key, str):
+            assert api_key.startswith("sk-") # prevent token hashes from being used
         if master_key is None:
             if isinstance(api_key, str):
                 return UserAPIKeyAuth(api_key=api_key)
@@ -288,8 +291,9 @@ async def user_api_key_auth(
             raise Exception("No connected db.")
 
         ## check for cache hit (In-Memory Cache)
+        if api_key.startswith("sk-"):
+            api_key = hash_token(token=api_key)
         valid_token = user_api_key_cache.get_cache(key=api_key)
-        verbose_proxy_logger.debug(f"valid_token from cache: {valid_token}")
         if valid_token is None:
             ## check db
             verbose_proxy_logger.debug(f"api key: {api_key}")
@@ -482,10 +486,10 @@ async def user_api_key_auth(
             )
 
         # Token passed all checks
-        # Add token to cache
-        user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
-
         api_key = valid_token.token
+
+        # Add hashed token to cache
+        user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
         valid_token_dict = _get_pydantic_json_dict(valid_token)
         valid_token_dict.pop("token", None)
         """
@@ -520,7 +524,10 @@ async def user_api_key_auth(
                 # check if user can access this route
                 query_params = request.query_params
                 key = query_params.get("key")
-                if prisma_client.hash_token(token=key) != api_key:
+                if (
+                    key is not None
+                    and prisma_client.hash_token(token=key) != api_key
+                ):
                     raise HTTPException(
                         status_code=status.HTTP_403_FORBIDDEN,
                         detail="user not allowed to access this key's info",
@@ -748,6 +755,9 @@ async def update_database(
 
         ### UPDATE KEY SPEND ###
         async def _update_key_db():
+            verbose_proxy_logger.debug(
+                f"adding spend to key db. Response cost: {response_cost}. Token: {token}."
+            )
             if prisma_client is not None:
                 # Fetch the existing cost for the given token
                 existing_spend_obj = await prisma_client.get_data(token=token)
@@ -1239,6 +1249,7 @@ async def generate_key_helper_fn(
     rpm_limit: Optional[int] = None,
     query_type: Literal["insert_data", "update_data"] = "insert_data",
     update_key_values: Optional[dict] = None,
+    key_alias: Optional[str] = None,
 ):
     global prisma_client, custom_db_client
 
@@ -1312,6 +1323,7 @@ async def generate_key_helper_fn(
         }
         key_data = {
             "token": token,
+            "key_alias": key_alias,
             "expires": expires,
             "models": models,
             "aliases": aliases_json,
@@ -1327,6 +1339,8 @@ async def generate_key_helper_fn(
             "budget_duration": key_budget_duration,
             "budget_reset_at": key_reset_at,
         }
+        if general_settings.get("allow_user_auth", False) == True:
+            key_data["key_name"] = f"sk-...{token[-4:]}"
         if prisma_client is not None:
             ## CREATE USER (If necessary)
             verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}")
@@ -2451,10 +2465,10 @@ async def delete_key_fn(data: DeleteKeyRequest):
     Delete a key from the key management system.
 
     Parameters::
-    - keys (List[str]): A list of keys to delete. Example {"keys": ["sk-QWrxEynunsNpV1zT48HIrw"]}
+    - keys (List[str]): A list of keys or hashed keys to delete. Example {"keys": ["sk-QWrxEynunsNpV1zT48HIrw", "837e17519f44683334df5291321d97b8bf1098cd490e49e215f6fea935aa28be"]}
 
     Returns:
-    - deleted_keys (List[str]): A list of deleted keys. Example {"deleted_keys": ["sk-QWrxEynunsNpV1zT48HIrw"]}
+    - deleted_keys (List[str]): A list of deleted keys. Example {"deleted_keys": ["sk-QWrxEynunsNpV1zT48HIrw", "837e17519f44683334df5291321d97b8bf1098cd490e49e215f6fea935aa28be"]}
 
 
     Raises:
@@ -2491,14 +2505,39 @@ async def delete_key_fn(data: DeleteKeyRequest):
     "/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
 )
 async def info_key_fn(
-    key: str = fastapi.Query(..., description="Key in the request parameters"),
+    key: Optional[str] = fastapi.Query(
+        default=None, description="Key in the request parameters"
+    ),
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
+    """
+    Retrieve information about a key.
+    Parameters:
+        key: Optional[str] = Query parameter representing the key in the request
+        user_api_key_dict: UserAPIKeyAuth = Dependency representing the user's API key
+    Returns:
+        Dict containing the key and its associated information
+
+    Example Curl:
+    ```
+    curl -X GET "http://0.0.0.0:8000/key/info?key=sk-02Wr4IAlN3NvPXvL5JVvDA" \
+    -H "Authorization: Bearer sk-1234"
+    ```
+
+    Example Curl - if no key is passed, it will use the Key Passed in Authorization Header
+    ```
+    curl -X GET "http://0.0.0.0:8000/key/info" \
+    -H "Authorization: Bearer sk-02Wr4IAlN3NvPXvL5JVvDA"
+    ```
+    """
     global prisma_client
     try:
         if prisma_client is None:
             raise Exception(
                 f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
             )
+        if key == None:
+            key = user_api_key_dict.api_key
         key_info = await prisma_client.get_data(token=key)
         ## REMOVE HASHED TOKEN INFO BEFORE RETURNING ##
         try:
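proxy_server.py now hashes raw `sk-` keys before every cache and DB lookup, caches the token under its hash, and also accepts already-hashed keys (the 64-character hex strings shown in the `delete_key_fn` docstring). A self-contained sketch of that flow; treating `hash_token` as a SHA-256 hex digest of the raw key is an assumption for illustration, not something this diff confirms:

```python
import hashlib

def hash_token(token: str) -> str:
    # Assumed implementation for the sketch: SHA-256 hex digest of the raw key.
    return hashlib.sha256(token.encode()).hexdigest()

user_api_key_cache = {}  # stand-in for the proxy's in-memory cache

def lookup_key(api_key: str):
    # Raw keys are hashed before the cache/DB lookup; hashed keys pass through,
    # so endpoints like /key/delete and /key/info can accept either form.
    if api_key.startswith("sk-"):
        api_key = hash_token(api_key)
    return user_api_key_cache.get(api_key)

raw_key = "sk-QWrxEynunsNpV1zT48HIrw"
user_api_key_cache[hash_token(raw_key)] = {"spend": 0.0}
assert lookup_key(raw_key) is not None              # raw key resolves via its hash
assert lookup_key(hash_token(raw_key)) is not None  # hashed key works directly
```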
@@ -7,6 +7,7 @@ generator client {
   provider = "prisma-client-py"
 }
 
+// Track spend, rate limit, budget Users
 model LiteLLM_UserTable {
   user_id String @unique
   team_id String?
@@ -21,9 +22,11 @@ model LiteLLM_UserTable {
   budget_reset_at DateTime?
 }
 
-// required for token gen
+// Generate Tokens for Proxy
 model LiteLLM_VerificationToken {
   token String @unique
+  key_name String?
+  key_alias String?
   spend Float @default(0.0)
   expires DateTime?
   models String[]
@@ -40,22 +43,25 @@ model LiteLLM_VerificationToken {
   budget_reset_at DateTime?
 }
 
+// store proxy config.yaml
 model LiteLLM_Config {
   param_name String @id
   param_value Json?
 }
 
+// View spend, model, api_key per request
 model LiteLLM_SpendLogs {
   request_id String @unique
   call_type String
   api_key String @default ("")
   spend Float @default(0.0)
+  total_tokens Int @default(0)
+  prompt_tokens Int @default(0)
+  completion_tokens Int @default(0)
   startTime DateTime // Assuming start_time is a DateTime field
   endTime DateTime // Assuming end_time is a DateTime field
   model String @default("")
   user String @default("")
-  modelParameters Json @default("{}") // Assuming optional_params is a JSON field
-  usage Json @default("{}")
   metadata Json @default("{}")
   cache_hit String @default("")
   cache_key String @default("")
@@ -198,7 +198,14 @@ class ProxyLogging:
             max_budget = user_info["max_budget"]
             spend = user_info["spend"]
             user_email = user_info["user_email"]
-            user_info = f"""\nUser ID: {user_id}\nMax Budget: {max_budget}\nSpend: {spend}\nUser Email: {user_email}"""
+            user_info = f"""\nUser ID: {user_id}\nMax Budget: ${max_budget}\nSpend: ${spend}\nUser Email: {user_email}"""
+        elif type == "token_budget":
+            token_info = dict(user_info)
+            token = token_info["token"]
+            spend = token_info["spend"]
+            max_budget = token_info["max_budget"]
+            user_id = token_info["user_id"]
+            user_info = f"""\nToken: {token}\nSpend: ${spend}\nMax Budget: ${max_budget}\nUser ID: {user_id}"""
         else:
             user_info = str(user_info)
         # percent of max_budget left to spend
@@ -814,7 +821,13 @@ class PrismaClient:
         Allow user to delete a key(s)
         """
         try:
-            hashed_tokens = [self.hash_token(token=token) for token in tokens]
+            hashed_tokens = []
+            for token in tokens:
+                if isinstance(token, str) and token.startswith("sk-"):
+                    hashed_token = self.hash_token(token=token)
+                else:
+                    hashed_token = token
+                hashed_tokens.append(hashed_token)
             await self.db.litellm_verificationtoken.delete_many(
                 where={"token": {"in": hashed_tokens}}
             )
@@ -1060,10 +1073,11 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
     metadata = (
         litellm_params.get("metadata", {}) or {}
     ) # if litellm_params['metadata'] == None
-    optional_params = kwargs.get("optional_params", {})
     call_type = kwargs.get("call_type", "litellm.completion")
     cache_hit = kwargs.get("cache_hit", False)
     usage = response_obj["usage"]
+    if type(usage) == litellm.Usage:
+        usage = dict(usage)
     id = response_obj.get("id", str(uuid.uuid4()))
     api_key = metadata.get("user_api_key", "")
     if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
@@ -1091,10 +1105,11 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
         "endTime": end_time,
         "model": kwargs.get("model", ""),
         "user": kwargs.get("user", ""),
-        "modelParameters": optional_params,
-        "usage": usage,
         "metadata": metadata,
         "cache_key": cache_key,
+        "total_tokens": usage.get("total_tokens", 0),
+        "prompt_tokens": usage.get("prompt_tokens", 0),
+        "completion_tokens": usage.get("completion_tokens", 0),
     }
 
     json_fields = [
@@ -1119,8 +1134,6 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
             payload[param] = payload[param].model_dump_json()
         if type(payload[param]) == litellm.EmbeddingResponse:
             payload[param] = payload[param].model_dump_json()
-        elif type(payload[param]) == litellm.Usage:
-            payload[param] = payload[param].model_dump_json()
         else:
             payload[param] = json.dumps(payload[param])
 
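With the `modelParameters` and `usage` JSON columns dropped, `get_logging_payload` now flattens usage into scalar token columns. A small sketch of that flattening, where `usage` may arrive as a dict or as a `litellm.Usage`-style object that supports `dict()` conversion (as the diff relies on):

```python
def flatten_usage(usage) -> dict:
    # Normalize a Usage-like object to a dict, then pull out the three token
    # counters that now live in dedicated LiteLLM_SpendLogs columns.
    if not isinstance(usage, dict):
        usage = dict(usage)
    return {
        "total_tokens": usage.get("total_tokens", 0),
        "prompt_tokens": usage.get("prompt_tokens", 0),
        "completion_tokens": usage.get("completion_tokens", 0),
    }

print(flatten_usage({"prompt_tokens": 12, "completion_tokens": 30, "total_tokens": 42}))
# {'total_tokens': 42, 'prompt_tokens': 12, 'completion_tokens': 30}
```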
@@ -723,8 +723,8 @@ def test_cache_override():
     print(f"Embedding 2 response time: {end_time - start_time} seconds")
 
     assert (
-        end_time - start_time > 0.1
-    ) # ensure 2nd response comes in over 0.1s. This should not be cached.
+        end_time - start_time > 0.05
+    ) # ensure 2nd response comes in over 0.05s. This should not be cached.
 
 
 # test_cache_override()
@@ -124,7 +124,7 @@ def test_cost_azure_gpt_35():
     )
 
 
-test_cost_azure_gpt_35()
+# test_cost_azure_gpt_35()
 
 
 def test_cost_azure_embedding():
@@ -165,3 +165,71 @@ def test_cost_openai_image_gen():
         model="dall-e-2", size="1024-x-1024", quality="standard", n=1
     )
     assert cost == 0.019922944
+
+
+def test_cost_bedrock_pricing():
+    """
+    - get pricing specific to region for a model
+    """
+    from litellm import ModelResponse, Choices, Message
+    from litellm.utils import Usage
+
+    litellm.set_verbose = True
+    input_tokens = litellm.token_counter(
+        model="bedrock/anthropic.claude-instant-v1",
+        messages=[{"role": "user", "content": "Hey, how's it going?"}],
+    )
+    print(f"input_tokens: {input_tokens}")
+    output_tokens = litellm.token_counter(
+        model="bedrock/anthropic.claude-instant-v1",
+        text="It's all going well",
+        count_response_tokens=True,
+    )
+    print(f"output_tokens: {output_tokens}")
+    resp = ModelResponse(
+        id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac",
+        choices=[
+            Choices(
+                finish_reason=None,
+                index=0,
+                message=Message(
+                    content="It's all going well",
+                    role="assistant",
+                ),
+            )
+        ],
+        created=1700775391,
+        model="anthropic.claude-instant-v1",
+        object="chat.completion",
+        system_fingerprint=None,
+        usage=Usage(
+            prompt_tokens=input_tokens,
+            completion_tokens=output_tokens,
+            total_tokens=input_tokens + output_tokens,
+        ),
+    )
+    resp._hidden_params = {
+        "custom_llm_provider": "bedrock",
+        "region_name": "ap-northeast-1",
+    }
+
+    cost = litellm.completion_cost(
+        model="anthropic.claude-instant-v1",
+        completion_response=resp,
+        messages=[{"role": "user", "content": "Hey, how's it going?"}],
+    )
+    predicted_cost = input_tokens * 0.00000223 + 0.00000755 * output_tokens
+    assert cost == predicted_cost
+
+
+def test_cost_bedrock_pricing_actual_calls():
+    litellm.set_verbose = True
+    model = "anthropic.claude-instant-v1"
+    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+    response = litellm.completion(model=model, messages=messages)
+    assert response._hidden_params["region_name"] is not None
+    cost = litellm.completion_cost(
+        completion_response=response,
+        messages=[{"role": "user", "content": "Hey, how's it going?"}],
+    )
+    assert cost > 0
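The test above exercises the region-aware lookup added to `cost_per_token` in utils.py: when `_hidden_params` carries a `region_name`, the cost map is consulted under `<provider>/<region>/<model>` first and falls back to `<provider>/<model>`. A minimal sketch of that key-selection logic; the prices in the toy cost map are illustrative stand-ins, not litellm's published rates:

```python
from typing import Optional

# Toy cost map; real lookups go against litellm.model_cost.
model_cost_ref = {
    "bedrock/anthropic.claude-instant-v1": {"input_cost_per_token": 8e-07},
    "bedrock/ap-northeast-1/anthropic.claude-instant-v1": {"input_cost_per_token": 2.23e-06},
}

def pricing_key(model: str, custom_llm_provider: str, region_name: Optional[str]) -> str:
    model_with_provider = f"{custom_llm_provider}/{model}"
    if region_name is not None:
        regional = f"{custom_llm_provider}/{region_name}/{model}"
        if regional in model_cost_ref:  # use region based pricing, if it's available
            return regional
    return model_with_provider

print(pricing_key("anthropic.claude-instant-v1", "bedrock", "ap-northeast-1"))
# bedrock/ap-northeast-1/anthropic.claude-instant-v1
print(pricing_key("anthropic.claude-instant-v1", "bedrock", "eu-west-1"))
# bedrock/anthropic.claude-instant-v1 (no regional entry, falls back)
```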
@@ -53,9 +53,9 @@ model_list:
     api_key: os.environ/AZURE_API_KEY
     api_version: 2023-07-01-preview
     model: azure/azure-embedding-model
-  model_name: azure-embedding-model
   model_info:
-    mode: "embedding"
+    mode: embedding
+  model_name: azure-embedding-model
 - litellm_params:
     model: gpt-3.5-turbo
   model_info:
@@ -80,43 +80,49 @@ model_list:
     description: this is a test openai model
     id: 9b1ef341-322c-410a-8992-903987fef439
   model_name: test_openai_models
-- model_name: amazon-embeddings
-  litellm_params:
-    model: "bedrock/amazon.titan-embed-text-v1"
-  model_info:
-    mode: embedding
-- model_name: "GPT-J 6B - Sagemaker Text Embedding (Internal)"
-  litellm_params:
-    model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16"
-  model_info:
-    mode: embedding
-- model_name: dall-e-3
-  litellm_params:
-    model: dall-e-3
-  model_info:
-    mode: image_generation
-- model_name: dall-e-3
-  litellm_params:
-    model: "azure/dall-e-3-test"
-    api_version: "2023-12-01-preview"
-    api_base: "os.environ/AZURE_SWEDEN_API_BASE"
-    api_key: "os.environ/AZURE_SWEDEN_API_KEY"
-  model_info:
-    mode: image_generation
-- model_name: dall-e-2
-  litellm_params:
-    model: "azure/"
-    api_version: "2023-06-01-preview"
-    api_base: "os.environ/AZURE_API_BASE"
-    api_key: "os.environ/AZURE_API_KEY"
-  model_info:
-    mode: image_generation
-- model_name: text-embedding-ada-002
-  litellm_params:
-    model: azure/azure-embedding-model
-    api_base: "os.environ/AZURE_API_BASE"
-    api_key: "os.environ/AZURE_API_KEY"
-    api_version: "2023-07-01-preview"
-  model_info:
-    mode: embedding
-    base_model: text-embedding-ada-002
+- litellm_params:
+    model: bedrock/amazon.titan-embed-text-v1
+  model_info:
+    mode: embedding
+  model_name: amazon-embeddings
+- litellm_params:
+    model: sagemaker/berri-benchmarking-gpt-j-6b-fp16
+  model_info:
+    mode: embedding
+  model_name: GPT-J 6B - Sagemaker Text Embedding (Internal)
+- litellm_params:
+    model: dall-e-3
+  model_info:
+    mode: image_generation
+  model_name: dall-e-3
+- litellm_params:
+    api_base: os.environ/AZURE_SWEDEN_API_BASE
+    api_key: os.environ/AZURE_SWEDEN_API_KEY
+    api_version: 2023-12-01-preview
+    model: azure/dall-e-3-test
+  model_info:
+    mode: image_generation
+  model_name: dall-e-3
+- litellm_params:
+    api_base: os.environ/AZURE_API_BASE
+    api_key: os.environ/AZURE_API_KEY
+    api_version: 2023-06-01-preview
+    model: azure/
+  model_info:
+    mode: image_generation
+  model_name: dall-e-2
+- litellm_params:
+    api_base: os.environ/AZURE_API_BASE
+    api_key: os.environ/AZURE_API_KEY
+    api_version: 2023-07-01-preview
+    model: azure/azure-embedding-model
+  model_info:
+    base_model: text-embedding-ada-002
+    mode: embedding
+  model_name: text-embedding-ada-002
+- litellm_params:
+    model: gpt-3.5-turbo
+  model_info:
+    description: this is a test openai model
+    id: 34cb2419-7c63-44ae-a189-53f1d1ce5953
+  model_name: test_openai_models
@@ -819,47 +819,49 @@ async def test_async_embedding_azure_caching():
 # Image Generation
 
 
-# ## Test OpenAI + Sync
-# def test_image_generation_openai():
-#     try:
-#         customHandler_success = CompletionCustomHandler()
-#         customHandler_failure = CompletionCustomHandler()
+## Test OpenAI + Sync
+def test_image_generation_openai():
+    try:
+        customHandler_success = CompletionCustomHandler()
+        customHandler_failure = CompletionCustomHandler()
         # litellm.callbacks = [customHandler_success]
 
         # litellm.set_verbose = True
 
         # response = litellm.image_generation(
         #     prompt="A cute baby sea otter", model="dall-e-3"
         # )
 
         # print(f"response: {response}")
         # assert len(response.data) > 0
 
         # print(f"customHandler_success.errors: {customHandler_success.errors}")
         # print(f"customHandler_success.states: {customHandler_success.states}")
         # assert len(customHandler_success.errors) == 0
         # assert len(customHandler_success.states) == 3 # pre, post, success
-#         # test failure callback
-#         litellm.callbacks = [customHandler_failure]
-#         try:
-#             response = litellm.image_generation(
-#                 prompt="A cute baby sea otter", model="dall-e-4"
-#             )
-#         except:
-#             pass
-#         print(f"customHandler_failure.errors: {customHandler_failure.errors}")
-#         print(f"customHandler_failure.states: {customHandler_failure.states}")
-#         assert len(customHandler_failure.errors) == 0
-#         assert len(customHandler_failure.states) == 3 # pre, post, failure
-#     except litellm.RateLimitError as e:
-#         pass
-#     except litellm.ContentPolicyViolationError:
-#         pass # OpenAI randomly raises these errors - skip when they occur
-#     except Exception as e:
-#         pytest.fail(f"An exception occurred - {str(e)}")
+        # test failure callback
+        litellm.callbacks = [customHandler_failure]
+        try:
+            response = litellm.image_generation(
+                prompt="A cute baby sea otter",
+                model="dall-e-2",
+                api_key="my-bad-api-key",
+            )
+        except:
+            pass
+        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
+        print(f"customHandler_failure.states: {customHandler_failure.states}")
+        assert len(customHandler_failure.errors) == 0
+        assert len(customHandler_failure.states) == 3 # pre, post, failure
+    except litellm.RateLimitError as e:
+        pass
+    except litellm.ContentPolicyViolationError:
+        pass # OpenAI randomly raises these errors - skip when they occur
+    except Exception as e:
+        pytest.fail(f"An exception occurred - {str(e)}")
 
 
-# test_image_generation_openai()
+test_image_generation_openai()
 ## Test OpenAI + Async
 
 ## Test Azure + Sync
@@ -64,7 +64,9 @@ def test_openai_embedding_3():
             model="text-embedding-3-small",
             input=["good morning from litellm", "this is another item"],
             metadata={"anything": "good day"},
+            dimensions=5,
         )
+        print(f"response:", response)
         litellm_response = dict(response)
         litellm_response_keys = set(litellm_response.keys())
         litellm_response_keys.discard("_response_ms")
@@ -80,6 +82,7 @@ def test_openai_embedding_3():
         response = client.embeddings.create(
             model="text-embedding-3-small",
             input=["good morning from litellm", "this is another item"],
+            dimensions=5,
         )
 
         response = dict(response)
@@ -33,7 +33,7 @@ from litellm.proxy.proxy_server import (
 )
 
 from litellm.proxy._types import NewUserRequest, DynamoDBArgs, GenerateKeyRequest
-from litellm.proxy.utils import DBClient
+from litellm.proxy.utils import DBClient, hash_token
 from starlette.datastructures import URL
 
 
@@ -232,7 +232,7 @@ def test_call_with_user_over_budget(custom_db_client):
                 "stream": False,
                 "litellm_params": {
                     "metadata": {
-                        "user_api_key": generated_key,
+                        "user_api_key": hash_token(generated_key),
                         "user_api_key_user_id": user_id,
                     }
                 },
@@ -305,7 +305,7 @@ def test_call_with_user_over_budget_stream(custom_db_client):
                 "complete_streaming_response": resp,
                 "litellm_params": {
                     "metadata": {
-                        "user_api_key": generated_key,
+                        "user_api_key": hash_token(generated_key),
                         "user_api_key_user_id": user_id,
                     }
                 },
@@ -376,7 +376,7 @@ def test_call_with_user_key_budget(custom_db_client):
                 "stream": False,
                 "litellm_params": {
                     "metadata": {
-                        "user_api_key": generated_key,
+                        "user_api_key": hash_token(generated_key),
                         "user_api_key_user_id": user_id,
                     }
                 },
@@ -449,7 +449,7 @@ def test_call_with_key_over_budget_stream(custom_db_client):
                 "complete_streaming_response": resp,
                 "litellm_params": {
                     "metadata": {
-                        "user_api_key": generated_key,
+                        "user_api_key": hash_token(generated_key),
                         "user_api_key_user_id": user_id,
                     }
                 },
@@ -12,6 +12,8 @@
 # 11. Generate a Key, cal key/info, call key/update, call key/info
 # 12. Make a call with key over budget, expect to fail
 # 14. Make a streaming chat/completions call with key over budget, expect to fail
+# 15. Generate key, when `allow_user_auth`=False - check if `/key/info` returns key_name=null
+# 16. Generate key, when `allow_user_auth`=True - check if `/key/info` returns key_name=sk...<last-4-digits>
 
 
 # function to call to generate key - async def new_user(data: NewUserRequest):
@@ -46,7 +48,7 @@ from litellm.proxy.proxy_server import (
     spend_key_fn,
     view_spend_logs,
 )
-from litellm.proxy.utils import PrismaClient, ProxyLogging
+from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token
 from litellm._logging import verbose_proxy_logger
 
 verbose_proxy_logger.setLevel(level=logging.DEBUG)
@@ -86,6 +88,7 @@ def prisma_client():
     litellm.proxy.proxy_server.litellm_proxy_budget_name = (
         f"litellm-proxy-budget-{time.time()}"
     )
+    litellm.proxy.proxy_server.user_custom_key_generate = None
 
     return prisma_client
 
@@ -918,7 +921,7 @@ def test_call_with_key_over_budget(prisma_client):
                 "stream": False,
                 "litellm_params": {
                     "metadata": {
-                        "user_api_key": generated_key,
+                        "user_api_key": hash_token(generated_key),
                         "user_api_key_user_id": user_id,
                     }
                 },
@@ -1009,7 +1012,7 @@ async def test_call_with_key_never_over_budget(prisma_client):
                 "stream": False,
                 "litellm_params": {
                     "metadata": {
-                        "user_api_key": generated_key,
+                        "user_api_key": hash_token(generated_key),
                         "user_api_key_user_id": user_id,
                     }
                 },
@@ -1083,7 +1086,7 @@ async def test_call_with_key_over_budget_stream(prisma_client):
                 "complete_streaming_response": resp,
                 "litellm_params": {
                     "metadata": {
-                        "user_api_key": generated_key,
+                        "user_api_key": hash_token(generated_key),
                         "user_api_key_user_id": user_id,
                     }
                 },
@@ -1140,3 +1143,48 @@ async def test_view_spend_per_key(prisma_client):
     except Exception as e:
         print("Got Exception", e)
         pytest.fail(f"Got exception {e}")
+
+
+@pytest.mark.asyncio()
+async def test_key_name_null(prisma_client):
+    """
+    - create key
+    - get key info
+    - assert key_name is null
+    """
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+    try:
+        request = GenerateKeyRequest()
+        key = await generate_key_fn(request)
+        generated_key = key.key
+        result = await info_key_fn(key=generated_key)
+        print("result from info_key_fn", result)
+        assert result["info"]["key_name"] is None
+    except Exception as e:
+        print("Got Exception", e)
+        pytest.fail(f"Got exception {e}")
+
+
+@pytest.mark.asyncio()
+async def test_key_name_set(prisma_client):
+    """
+    - create key
+    - get key info
+    - assert key_name is not null
+    """
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True})
+    await litellm.proxy.proxy_server.prisma_client.connect()
+    try:
+        request = GenerateKeyRequest()
+        key = await generate_key_fn(request)
+        generated_key = key.key
+        result = await info_key_fn(key=generated_key)
+        print("result from info_key_fn", result)
+        assert isinstance(result["info"]["key_name"], str)
+    except Exception as e:
+        print("Got Exception", e)
+        pytest.fail(f"Got exception {e}")
@@ -32,7 +32,7 @@ from litellm.proxy.proxy_server import (
 ) # Replace with the actual module where your FastAPI router is defined
 
 # Your bearer token
-token = ""
+token = "sk-1234"
 
 headers = {"Authorization": f"Bearer {token}"}
 
@@ -31,7 +31,7 @@ from litellm.proxy.proxy_server import (
 ) # Replace with the actual module where your FastAPI router is defined
 
 # Your bearer token
-token = ""
+token = "sk-1234"
 
 headers = {"Authorization": f"Bearer {token}"}
 
@@ -33,7 +33,7 @@ from litellm.proxy.proxy_server import (
 ) # Replace with the actual module where your FastAPI router is defined
 
 # Your bearer token
-token = ""
+token = "sk-1234"
 
 headers = {"Authorization": f"Bearer {token}"}
 
119
litellm/utils.py
119
litellm/utils.py
|
@@ -714,6 +714,7 @@ class ImageResponse(OpenAIObject):
 ############################################################
 def print_verbose(print_statement):
     try:
+        verbose_logger.debug(print_statement)
         if litellm.set_verbose:
             print(print_statement)  # noqa
     except:

@@ -2029,14 +2030,15 @@ def client(original_function):
             start_time=start_time,
         )
         ## check if metadata is passed in
+        litellm_params = {}
         if "metadata" in kwargs:
-            litellm_params = {"metadata": kwargs["metadata"]}
+            litellm_params["metadata"] = kwargs["metadata"]
         logging_obj.update_environment_variables(
             model=model,
             user="",
             optional_params={},
             litellm_params=litellm_params,
         )
         return logging_obj
     except Exception as e:
         import logging

@@ -2900,6 +2902,7 @@ def cost_per_token(
     completion_tokens=0,
     response_time_ms=None,
     custom_llm_provider=None,
+    region_name=None,
 ):
     """
     Calculates the cost per token for a given model, prompt tokens, and completion tokens.
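The `litellm_params` change above initializes the dict up front so `metadata` is attached only when the caller supplies it. A minimal sketch of how caller-supplied metadata reaches this path (the model name and metadata keys are placeholders, not from this commit):

```python
import litellm

# Hypothetical call: the `metadata` kwarg is what ends up in
# litellm_params["metadata"] in the logging setup shown above and is then
# forwarded to logging callbacks.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    metadata={"trace_id": "abc-123", "team": "search"},
)
```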
@@ -2916,16 +2919,46 @@ def cost_per_token(
     prompt_tokens_cost_usd_dollar = 0
     completion_tokens_cost_usd_dollar = 0
     model_cost_ref = litellm.model_cost
+    model_with_provider = model
     if custom_llm_provider is not None:
         model_with_provider = custom_llm_provider + "/" + model
-    else:
-        model_with_provider = model
+        if region_name is not None:
+            model_with_provider_and_region = (
+                f"{custom_llm_provider}/{region_name}/{model}"
+            )
+            if (
+                model_with_provider_and_region in model_cost_ref
+            ):  # use region based pricing, if it's available
+                model_with_provider = model_with_provider_and_region
     # see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
-    verbose_logger.debug(f"Looking up model={model} in model_cost_map")
+    print_verbose(f"Looking up model={model} in model_cost_map")
+    if model_with_provider in model_cost_ref:
+        print_verbose(
+            f"Success: model={model_with_provider} in model_cost_map - {model_cost_ref[model_with_provider]}"
+        )
+        print_verbose(
+            f"applying cost={model_cost_ref[model_with_provider].get('input_cost_per_token', None)} for prompt_tokens={prompt_tokens}"
+        )
+        prompt_tokens_cost_usd_dollar = (
+            model_cost_ref[model_with_provider]["input_cost_per_token"] * prompt_tokens
+        )
+        print_verbose(
+            f"calculated prompt_tokens_cost_usd_dollar: {prompt_tokens_cost_usd_dollar}"
+        )
+        print_verbose(
+            f"applying cost={model_cost_ref[model_with_provider].get('output_cost_per_token', None)} for completion_tokens={completion_tokens}"
+        )
+        completion_tokens_cost_usd_dollar = (
+            model_cost_ref[model_with_provider]["output_cost_per_token"]
+            * completion_tokens
+        )
+        print_verbose(
+            f"calculated completion_tokens_cost_usd_dollar: {completion_tokens_cost_usd_dollar}"
+        )
+        return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
     if model in model_cost_ref:
-        verbose_logger.debug(f"Success: model={model} in model_cost_map")
-        verbose_logger.debug(
+        print_verbose(f"Success: model={model} in model_cost_map")
+        print_verbose(
             f"prompt_tokens={prompt_tokens}; completion_tokens={completion_tokens}"
         )
         if (
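A sketch of exercising the region-aware lookup above; the `azure/eu/...` key and the prices are illustrative placeholders, not real entries in litellm's cost map:

```python
import litellm

# Illustrative entry only - real keys and prices live in litellm.model_cost.
litellm.model_cost["azure/eu/gpt-35-turbo"] = {
    "input_cost_per_token": 0.0000020,
    "output_cost_per_token": 0.0000030,
}

# With region_name set, cost_per_token first tries
# "<provider>/<region>/<model>" before falling back to "<provider>/<model>".
prompt_cost, completion_cost = litellm.cost_per_token(
    model="gpt-35-turbo",
    prompt_tokens=100,
    completion_tokens=50,
    custom_llm_provider="azure",
    region_name="eu",
)
print(prompt_cost, completion_cost)
```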
@@ -2943,7 +2976,7 @@ def cost_per_token(
             model_cost_ref[model].get("input_cost_per_second", None) is not None
             and response_time_ms is not None
         ):
-            verbose_logger.debug(
+            print_verbose(
                 f"For model={model} - input_cost_per_second: {model_cost_ref[model].get('input_cost_per_second')}; response time: {response_time_ms}"
             )
             ## COST PER SECOND ##
@@ -2951,30 +2984,12 @@ def cost_per_token(
                 model_cost_ref[model]["input_cost_per_second"] * response_time_ms / 1000
             )
             completion_tokens_cost_usd_dollar = 0.0
-            verbose_logger.debug(
+            print_verbose(
                 f"Returned custom cost for model={model} - prompt_tokens_cost_usd_dollar: {prompt_tokens_cost_usd_dollar}, completion_tokens_cost_usd_dollar: {completion_tokens_cost_usd_dollar}"
             )
         return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
-    elif model_with_provider in model_cost_ref:
-        verbose_logger.debug(
-            f"Looking up model={model_with_provider} in model_cost_map"
-        )
-        verbose_logger.debug(
-            f"applying cost={model_cost_ref[model_with_provider]['input_cost_per_token']} for prompt_tokens={prompt_tokens}"
-        )
-        prompt_tokens_cost_usd_dollar = (
-            model_cost_ref[model_with_provider]["input_cost_per_token"] * prompt_tokens
-        )
-        verbose_logger.debug(
-            f"applying cost={model_cost_ref[model_with_provider]['output_cost_per_token']} for completion_tokens={completion_tokens}"
-        )
-        completion_tokens_cost_usd_dollar = (
-            model_cost_ref[model_with_provider]["output_cost_per_token"]
-            * completion_tokens
-        )
-        return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
     elif "ft:gpt-3.5-turbo" in model:
-        verbose_logger.debug(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
+        print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
         # fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm
         prompt_tokens_cost_usd_dollar = (
             model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens
@@ -3031,7 +3046,10 @@ def completion_cost(
     prompt="",
     messages: List = [],
     completion="",
-    total_time=0.0, # used for replicate
+    total_time=0.0, # used for replicate, sagemaker
+    ### REGION ###
+    custom_llm_provider=None,
+    region_name=None, # used for bedrock pricing
     ### IMAGE GEN ###
     size=None,
     quality=None,
@@ -3080,12 +3098,13 @@ def completion_cost(
         model = (
             model or completion_response["model"]
         ) # check if user passed an override for model, if it's none check completion_response['model']
-        if completion_response is not None and hasattr(
-            completion_response, "_hidden_params"
-        ):
+        if hasattr(completion_response, "_hidden_params"):
             custom_llm_provider = completion_response._hidden_params.get(
                 "custom_llm_provider", ""
             )
+            region_name = completion_response._hidden_params.get(
+                "region_name", region_name
+            )
     else:
         if len(messages) > 0:
             prompt_tokens = token_counter(model=model, messages=messages)
@@ -3146,8 +3165,13 @@ def completion_cost(
             completion_tokens=completion_tokens,
             custom_llm_provider=custom_llm_provider,
             response_time_ms=total_time,
+            region_name=region_name,
         )
-        return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
+        _final_cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
+        print_verbose(
+            f"final cost: {_final_cost}; prompt_tokens_cost_usd_dollar: {prompt_tokens_cost_usd_dollar}; completion_tokens_cost_usd_dollar: {completion_tokens_cost_usd_dollar}"
+        )
+        return _final_cost
     except Exception as e:
         raise e
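An assumed usage sketch for the updated `completion_cost`: provider and region are read from `response._hidden_params` when present, and the new keyword arguments cover the explicit case (the deployment and region names below are placeholders):

```python
import litellm

# Placeholder Azure deployment; real calls need the usual Azure env/config.
response = litellm.completion(
    model="azure/my-gpt-35-deployment",
    messages=[{"role": "user", "content": "hello"}],
)

# provider/region are picked up from response._hidden_params when available;
# the explicit kwargs below are only needed when they are not.
cost = litellm.completion_cost(
    completion_response=response,
    custom_llm_provider="azure",
    region_name="eu",  # illustrative region
)
print(f"request cost: ${cost:.6f}")
```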
@@ -3313,8 +3337,10 @@ def get_optional_params_image_gen(

 def get_optional_params_embeddings(
     # 2 optional params
+    model=None,
     user=None,
     encoding_format=None,
+    dimensions=None,
     custom_llm_provider="",
     **kwargs,
 ):
@@ -3325,7 +3351,7 @@ def get_optional_params_embeddings(
     for k, v in special_params.items():
         passed_params[k] = v

-    default_params = {"user": None, "encoding_format": None}
+    default_params = {"user": None, "encoding_format": None, "dimensions": None}

     non_default_params = {
         k: v
@@ -3333,6 +3359,19 @@ def get_optional_params_embeddings(
         if (k in default_params and v != default_params[k])
     }
     ## raise exception if non-default value passed for non-openai/azure embedding calls
+    if custom_llm_provider == "openai":
+        # 'dimensions` is only supported in `text-embedding-3` and later models
+
+        if (
+            model is not None
+            and "text-embedding-3" not in model
+            and "dimensions" in non_default_params.keys()
+        ):
+            raise UnsupportedParamsError(
+                status_code=500,
+                message=f"Setting dimensions is not supported for OpenAI `text-embedding-3` and later models. To drop it from the call, set `litellm.drop_params = True`.",
+            )
+
     if (
         custom_llm_provider != "openai"
         and custom_llm_provider != "azure"
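To illustrate the new guard, a hedged sketch (model names are OpenAI's public embedding models; an `OPENAI_API_KEY` is assumed to be configured): `dimensions` is accepted for `text-embedding-3-*` models and rejected for older ones.

```python
import litellm

# Supported: text-embedding-3 models accept a dimensions override.
litellm.embedding(
    model="text-embedding-3-small",
    input=["good morning from litellm"],
    dimensions=256,
)

# Older OpenAI embedding models do not: the guard above raises
# UnsupportedParamsError (its message suggests litellm.drop_params = True
# if you want the kwarg silently dropped instead).
try:
    litellm.embedding(
        model="text-embedding-ada-002",
        input=["good morning from litellm"],
        dimensions=256,
    )
except Exception as e:  # UnsupportedParamsError from the guard above
    print(e)
```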
6 poetry.lock generated

@@ -1158,13 +1158,13 @@ files = [

 [[package]]
 name = "openai"
-version = "1.8.0"
+version = "1.10.0"
 description = "The official Python library for the openai API"
 optional = false
 python-versions = ">=3.7.1"
 files = [
-    {file = "openai-1.8.0-py3-none-any.whl", hash = "sha256:0f8f53805826103fdd8adaf379ad3ec23f9d867e698cbc14caf34b778d150175"},
-    {file = "openai-1.8.0.tar.gz", hash = "sha256:93366be27802f517e89328801913d2a5ede45e3b86fdcab420385b8a1b88c767"},
+    {file = "openai-1.10.0-py3-none-any.whl", hash = "sha256:aa69e97d0223ace9835fbf9c997abe9ee95318f684fd2de6d02c870700c71ebc"},
+    {file = "openai-1.10.0.tar.gz", hash = "sha256:208886cb501b930dc63f48d51db9c15e5380380f80516d07332adad67c9f1053"},
 ]

 [package.dependencies]
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.19.4"
+version = "1.20.0"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"

@@ -63,7 +63,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"

 [tool.commitizen]
-version = "1.19.4"
+version = "1.20.0"
 version_files = [
     "pyproject.toml:^version"
 ]
@@ -25,6 +25,8 @@ model LiteLLM_UserTable {
 // Generate Tokens for Proxy
 model LiteLLM_VerificationToken {
   token String @unique
+  key_name String?
+  key_alias String?
   spend Float @default(0.0)
   expires DateTime?
   models String[]

@@ -53,12 +55,13 @@ model LiteLLM_SpendLogs {
   call_type String
   api_key String @default ("")
   spend Float @default(0.0)
+  total_tokens Int @default(0)
+  prompt_tokens Int @default(0)
+  completion_tokens Int @default(0)
   startTime DateTime // Assuming start_time is a DateTime field
   endTime DateTime // Assuming end_time is a DateTime field
   model String @default("")
   user String @default("")
-  modelParameters Json @default("{}")// Assuming optional_params is a JSON field
-  usage Json @default("{}")
   metadata Json @default("{}")
   cache_hit String @default("")
   cache_key String @default("")
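With the flattened token columns above, usage can be aggregated straight from the database. A sketch against a Postgres deployment (the DSN and identifier quoting are assumptions, not part of this commit):

```python
import psycopg2

# Assumed DSN; Prisma keeps the mixed-case table name, so it needs quoting.
conn = psycopg2.connect("postgresql://llmproxy:password@localhost:5432/litellm")
with conn, conn.cursor() as cur:
    cur.execute(
        '''
        SELECT model,
               SUM(prompt_tokens)     AS prompt_tokens,
               SUM(completion_tokens) AS completion_tokens,
               SUM(total_tokens)      AS total_tokens,
               SUM(spend)             AS spend
        FROM "LiteLLM_SpendLogs"
        GROUP BY model
        ORDER BY spend DESC;
        '''
    )
    for row in cur.fetchall():
        print(row)
```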
@@ -115,7 +115,9 @@ async def chat_completion(session, key, model="gpt-4"):
     print()

     if status != 200:
-        raise Exception(f"Request did not return a 200 status code: {status}")
+        raise Exception(
+            f"Request did not return a 200 status code: {status}. Response: {response_text}"
+        )

     return await response.json()
@@ -201,11 +203,14 @@ async def test_key_delete():
     )


-async def get_key_info(session, get_key, call_key):
+async def get_key_info(session, call_key, get_key=None):
     """
     Make sure only models user has access to are returned
     """
-    url = f"http://0.0.0.0:4000/key/info?key={get_key}"
+    if get_key is None:
+        url = "http://0.0.0.0:4000/key/info"
+    else:
+        url = f"http://0.0.0.0:4000/key/info?key={get_key}"
     headers = {
         "Authorization": f"Bearer {call_key}",
         "Content-Type": "application/json",
@@ -241,6 +246,9 @@ async def test_key_info():
         await get_key_info(session=session, get_key=key, call_key="sk-1234")
         # as key itself #
         await get_key_info(session=session, get_key=key, call_key=key)
+
+        # as key itself, use the auth param, and no query key needed
+        await get_key_info(session=session, call_key=key)
         # as random key #
         key_gen = await generate_key(session=session, i=0)
         random_key = key_gen["key"]
@@ -281,14 +289,20 @@ async def test_key_info_spend_values():
         await asyncio.sleep(5)
         spend_logs = await get_spend_logs(session=session, request_id=response["id"])
         print(f"spend_logs: {spend_logs}")
-        usage = spend_logs[0]["usage"]
+        completion_tokens = spend_logs[0]["completion_tokens"]
+        prompt_tokens = spend_logs[0]["prompt_tokens"]
+        print(f"prompt_tokens: {prompt_tokens}; completion_tokens: {completion_tokens}")
+
+        litellm.set_verbose = True
         prompt_cost, completion_cost = litellm.cost_per_token(
             model="gpt-35-turbo",
-            prompt_tokens=usage["prompt_tokens"],
-            completion_tokens=usage["completion_tokens"],
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
             custom_llm_provider="azure",
         )
+        print("prompt_cost: ", prompt_cost, "completion_cost: ", completion_cost)
         response_cost = prompt_cost + completion_cost
+        print(f"response_cost: {response_cost}")
         await asyncio.sleep(5) # allow db log to be updated
         key_info = await get_key_info(session=session, get_key=key, call_key=key)
         print(
@@ -380,3 +394,31 @@ async def test_key_with_budgets():
         key_info = await get_key_info(session=session, get_key=key, call_key=key)
         reset_at_new_value = key_info["info"]["budget_reset_at"]
         assert reset_at_init_value != reset_at_new_value
+
+
+@pytest.mark.asyncio
+async def test_key_crossing_budget():
+    """
+    - Create key with budget with budget=0.00000001
+    - make a /chat/completions call
+    - wait 5s
+    - make a /chat/completions call - should fail with key crossed it's budget
+
+    - Check if value updated
+    """
+    from litellm.proxy.utils import hash_token
+
+    async with aiohttp.ClientSession() as session:
+        key_gen = await generate_key(session=session, i=0, budget=0.0000001)
+        key = key_gen["key"]
+        hashed_token = hash_token(token=key)
+        print(f"hashed_token: {hashed_token}")
+
+        response = await chat_completion(session=session, key=key)
+        print("response 1: ", response)
+        await asyncio.sleep(2)
+        try:
+            response = await chat_completion(session=session, key=key)
+            pytest.fail("Should have failed - Key crossed it's budget")
+        except Exception as e:
+            assert "ExceededTokenBudget: Current spend for token:" in str(e)
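For context, a hedged sketch of the request the `generate_key` helper presumably issues for the budgeted key: the `budget` argument maps to the proxy's `max_budget` field on `/key/generate` (endpoint and field names assumed from the proxy API, not shown in this diff):

```python
import asyncio
import aiohttp

async def generate_budgeted_key():
    # Assumed request shape: create a key whose total spend may not exceed
    # $0.0000001, matching the tiny budget used by the test above.
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://0.0.0.0:4000/key/generate",
            headers={
                "Authorization": "Bearer sk-1234",
                "Content-Type": "application/json",
            },
            json={"models": ["gpt-3.5-turbo"], "max_budget": 0.0000001},
        ) as resp:
            return await resp.json()

print(asyncio.run(generate_budgeted_key()))
```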