Merge branch 'main' into litellm_budget_per_key

Ishaan Jaff 2024-01-22 15:49:57 -08:00 committed by GitHub
commit db68774d60
20 changed files with 731 additions and 183 deletions


@@ -0,0 +1,34 @@
import Image from '@theme/IdealImage';
# Custom Pricing - Sagemaker, etc.
Use this to register custom pricing (cost per token or cost per second) for models.
## Quick Start
Register custom pricing for sagemaker completion + embedding models.
For cost per second pricing, you **just** need to register `input_cost_per_second`.
**Step 1: Add pricing to config.yaml**
```yaml
model_list:
  - model_name: sagemaker-completion-model
    litellm_params:
      model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
      input_cost_per_second: 0.000420
  - model_name: sagemaker-embedding-model
    litellm_params:
      model: sagemaker/berri-benchmarking-gpt-j-6b-fp16
      input_cost_per_second: 0.000420
```
**Step 2: Start proxy**
```bash
litellm --config /path/to/config.yaml
```
**Step 3: View Spend Logs**
<Image img={require('../../img/spend_logs_table.png')} />
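
For reference, time-based pricing can also be exercised directly from the Python SDK instead of the proxy config. The sketch below is a minimal example under the same assumptions as the config above (the `sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4` endpoint is reachable with your AWS credentials); it passes `input_cost_per_second` as a kwarg and reads the computed spend back with `completion_cost`.

```python
from litellm import completion, completion_cost

# minimal sketch - assumes AWS credentials for the Sagemaker endpoint are already configured
response = completion(
    model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    input_cost_per_second=0.000420,  # registers time-based pricing for this model
)

# with cost-per-second pricing, spend is derived from the duration of the call
cost = completion_cost(completion_response=response)
print(f"cost for this call: ${cost:.6f}")
```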


@@ -440,6 +440,97 @@ general_settings:
$ litellm --config /path/to/config.yaml
```
## Custom /key/generate
Use this if you need to run custom logic before generating a Proxy API Key (for example, validating a `team_id`).
### 1. Write a custom `custom_generate_key_fn`
The input to the `custom_generate_key_fn` function is a single parameter: `data` [(Type: GenerateKeyRequest)](https://github.com/BerriAI/litellm/blob/main/litellm/proxy/_types.py#L125)
The output of your `custom_generate_key_fn` should be a dictionary with the following structure:
```python
{
    "decision": False,
    "message": "This violates LiteLLM Proxy Rules. No team id provided.",
}
```
- decision (Type: bool): A boolean value indicating whether the key generation is allowed (True) or not (False).
- message (Type: str, Optional): An optional message providing additional information about the decision. This field is included when the decision is False.
```python
async def custom_generate_key_fn(data: GenerateKeyRequest) -> dict:
    """
    Asynchronous function for generating a key based on the input data.

    Args:
        data (GenerateKeyRequest): The input data for key generation.

    Returns:
        dict: A dictionary containing the decision and an optional message.
        {
            "decision": False,
            "message": "This violates LiteLLM Proxy Rules. No team id provided.",
        }
    """

    # decide if a key should be generated or not
    print("using custom auth function!")
    data_json = data.json()  # type: ignore

    # Unpacking variables
    team_id = data_json.get("team_id")
    duration = data_json.get("duration")
    models = data_json.get("models")
    aliases = data_json.get("aliases")
    config = data_json.get("config")
    spend = data_json.get("spend")
    user_id = data_json.get("user_id")
    max_parallel_requests = data_json.get("max_parallel_requests")
    metadata = data_json.get("metadata")
    tpm_limit = data_json.get("tpm_limit")
    rpm_limit = data_json.get("rpm_limit")

    if team_id is not None and team_id == "litellm-core-infra@gmail.com":
        # only team_id="litellm-core-infra@gmail.com" can make keys
        return {
            "decision": True,
        }
    else:
        print("Failed custom auth")
        return {
            "decision": False,
            "message": "This violates LiteLLM Proxy Rules. No team id provided.",
        }
```
### 2. Pass the filepath (relative to the config.yaml)
Pass the filepath of your custom key-generate function in the config.yaml.
e.g. if both files are in the same directory - `./config.yaml` and `./custom_auth.py` - this is what it looks like:
```yaml
model_list:
  - model_name: "openai-model"
    litellm_params:
      model: "gpt-3.5-turbo"

litellm_settings:
  drop_params: True
  set_verbose: True

general_settings:
  custom_key_generate: custom_auth.custom_generate_key_fn
```
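
Once the proxy is running with this config, you can check the rule end to end by calling `/key/generate`. The snippet below is a hypothetical check using `requests`; the base URL `http://0.0.0.0:8000` and master key `sk-1234` are assumptions - substitute your own deployment values. A request without a `team_id` should come back as a `403` carrying the message returned by your `custom_generate_key_fn`.

```python
import requests

PROXY_BASE_URL = "http://0.0.0.0:8000"  # assumed local proxy address
MASTER_KEY = "sk-1234"                  # assumed master key

# no team_id -> the custom rule above should reject this with a 403
resp = requests.post(
    f"{PROXY_BASE_URL}/key/generate",
    headers={"Authorization": f"Bearer {MASTER_KEY}"},
    json={},
)
print(resp.status_code, resp.text)

# team_id allowed by the example rule -> key generation proceeds
resp = requests.post(
    f"{PROXY_BASE_URL}/key/generate",
    headers={"Authorization": f"Bearer {MASTER_KEY}"},
    json={"team_id": "litellm-core-infra@gmail.com"},
)
print(resp.json())
```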
## [BETA] Dynamo DB

(Binary image file added - 189 KiB - not shown)


@@ -139,6 +139,7 @@ const sidebars = {
        "items": [
          "proxy/call_hooks",
          "proxy/rules",
+         "proxy/custom_pricing"
        ]
      },
      "proxy/deploy",


@@ -12,15 +12,6 @@ formatter = logging.Formatter("\033[92m%(name)s - %(levelname)s\033[0m: %(messag
handler.setFormatter(formatter)

-def print_verbose(print_statement):
-    try:
-        if set_verbose:
-            print(print_statement) # noqa
-    except:
-        pass

verbose_proxy_logger = logging.getLogger("LiteLLM Proxy")
verbose_router_logger = logging.getLogger("LiteLLM Router")
verbose_logger = logging.getLogger("LiteLLM")

@@ -29,3 +20,18 @@ verbose_logger = logging.getLogger("LiteLLM")
verbose_router_logger.addHandler(handler)
verbose_proxy_logger.addHandler(handler)
verbose_logger.addHandler(handler)

+def print_verbose(print_statement):
+    try:
+        if set_verbose:
+            print(print_statement) # noqa
+            verbose_logger.setLevel(level=logging.DEBUG)  # set package log to debug
+            verbose_router_logger.setLevel(
+                level=logging.DEBUG
+            )  # set router logs to debug
+            verbose_proxy_logger.setLevel(
+                level=logging.DEBUG
+            )  # set proxy logs to debug
+    except:
+        pass


@@ -629,12 +629,23 @@ class AzureChatCompletion(BaseLLM):
                client_session = litellm.aclient_session or httpx.AsyncClient(
                    transport=AsyncCustomHTTPTransport(),
                )
-               openai_aclient = AsyncAzureOpenAI(
+               azure_client = AsyncAzureOpenAI(
                    http_client=client_session, **azure_client_params
                )
            else:
-               openai_aclient = client
-           response = await openai_aclient.images.generate(**data, timeout=timeout)
+               azure_client = client
+           ## LOGGING
+           logging_obj.pre_call(
+               input=data["prompt"],
+               api_key=azure_client.api_key,
+               additional_args={
+                   "headers": {"api_key": azure_client.api_key},
+                   "api_base": azure_client._base_url._uri_reference,
+                   "acompletion": True,
+                   "complete_input_dict": data,
+               },
+           )
+           response = await azure_client.images.generate(**data, timeout=timeout)
            stringified_response = response.model_dump()
            ## LOGGING
            logging_obj.post_call(

@@ -719,7 +730,7 @@ class AzureChatCompletion(BaseLLM):
                input=prompt,
                api_key=azure_client.api_key,
                additional_args={
-                   "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
+                   "headers": {"api_key": azure_client.api_key},
                    "api_base": azure_client._base_url._uri_reference,
                    "acompletion": False,
                    "complete_input_dict": data,


@@ -43,7 +43,7 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
                    request=request,
                )
-               time.sleep(int(response.headers.get("retry-after")) or 10)
+               await asyncio.sleep(int(response.headers.get("retry-after") or 10))
                response = await super().handle_async_request(request)
                await response.aread()

@@ -95,7 +95,6 @@ class CustomHTTPTransport(httpx.HTTPTransport):
                request.method = "GET"
                response = super().handle_request(request)
                response.read()
-
            timeout_secs: int = 120
            start_time = time.time()
            while response.json()["status"] not in ["succeeded", "failed"]:

@@ -112,11 +111,9 @@ class CustomHTTPTransport(httpx.HTTPTransport):
                        content=json.dumps(timeout).encode("utf-8"),
                        request=request,
                    )
-               time.sleep(int(response.headers.get("retry-after")) or 10)
+               time.sleep(int(response.headers.get("retry-after", None) or 10))
                response = super().handle_request(request)
                response.read()
            if response.json()["status"] == "failed":
                error_data = response.json()
                return httpx.Response(


@@ -348,6 +348,13 @@ def mock_completion(
            prompt_tokens=10, completion_tokens=20, total_tokens=30
        )
+       try:
+           _, custom_llm_provider, _, _ = litellm.utils.get_llm_provider(model=model)
+           model_response._hidden_params["custom_llm_provider"] = custom_llm_provider
+       except:
+           # dont let setting a hidden param block a mock_respose
+           pass
        return model_response
    except:

@@ -450,6 +457,8 @@ def completion(
    ### CUSTOM MODEL COST ###
    input_cost_per_token = kwargs.get("input_cost_per_token", None)
    output_cost_per_token = kwargs.get("output_cost_per_token", None)
+   input_cost_per_second = kwargs.get("input_cost_per_second", None)
+   output_cost_per_second = kwargs.get("output_cost_per_second", None)
    ### CUSTOM PROMPT TEMPLATE ###
    initial_prompt_value = kwargs.get("initial_prompt_value", None)
    roles = kwargs.get("roles", None)

@@ -527,6 +536,8 @@ def completion(
        "tpm",
        "input_cost_per_token",
        "output_cost_per_token",
+       "input_cost_per_second",
+       "output_cost_per_second",
        "hf_model_name",
        "model_info",
        "proxy_server_request",

@@ -589,6 +600,19 @@ def completion(
                    }
                }
            )
+       if (
+           input_cost_per_second is not None
+       ):  # time based pricing just needs cost in place
+           output_cost_per_second = output_cost_per_second or 0.0
+           litellm.register_model(
+               {
+                   model: {
+                       "input_cost_per_second": input_cost_per_second,
+                       "output_cost_per_second": output_cost_per_second,
+                       "litellm_provider": custom_llm_provider,
+                   }
+               }
+           )
        ### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ###
        custom_prompt_dict = {}  # type: ignore
        if (

@@ -2240,6 +2264,11 @@ def embedding(
    encoding_format = kwargs.get("encoding_format", None)
    proxy_server_request = kwargs.get("proxy_server_request", None)
    aembedding = kwargs.get("aembedding", None)
+   ### CUSTOM MODEL COST ###
+   input_cost_per_token = kwargs.get("input_cost_per_token", None)
+   output_cost_per_token = kwargs.get("output_cost_per_token", None)
+   input_cost_per_second = kwargs.get("input_cost_per_second", None)
+   output_cost_per_second = kwargs.get("output_cost_per_second", None)
    openai_params = [
        "user",
        "request_timeout",

@@ -2288,6 +2317,8 @@ def embedding(
        "tpm",
        "input_cost_per_token",
        "output_cost_per_token",
+       "input_cost_per_second",
+       "output_cost_per_second",
        "hf_model_name",
        "proxy_server_request",
        "model_info",

@@ -2313,6 +2344,28 @@ def embedding(
        custom_llm_provider=custom_llm_provider,
        **non_default_params,
    )
+   ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
+   if input_cost_per_token is not None and output_cost_per_token is not None:
+       litellm.register_model(
+           {
+               model: {
+                   "input_cost_per_token": input_cost_per_token,
+                   "output_cost_per_token": output_cost_per_token,
+                   "litellm_provider": custom_llm_provider,
+               }
+           }
+       )
+   if input_cost_per_second is not None:  # time based pricing just needs cost in place
+       output_cost_per_second = output_cost_per_second or 0.0
+       litellm.register_model(
+           {
+               model: {
+                   "input_cost_per_second": input_cost_per_second,
+                   "output_cost_per_second": output_cost_per_second,
+                   "litellm_provider": custom_llm_provider,
+               }
+           }
+       )
    try:
        response = None
        logging = litellm_logging_obj

@@ -3281,7 +3334,9 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]
    return response

-def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
+def stream_chunk_builder(
+    chunks: list, messages: Optional[list] = None, start_time=None, end_time=None
+):
    model_response = litellm.ModelResponse()
    # set hidden params from chunk to model_response
    if model_response is not None and hasattr(model_response, "_hidden_params"):

@@ -3456,5 +3511,8 @@ def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
        response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
    )
    return convert_to_model_response_object(
-       response_object=response, model_response_object=model_response
+       response_object=response,
+       model_response_object=model_response,
+       start_time=start_time,
+       end_time=end_time,
    )


@@ -1,4 +1,4 @@
-from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy._types import UserAPIKeyAuth, GenerateKeyRequest
from fastapi import Request
from dotenv import load_dotenv
import os

@@ -14,3 +14,40 @@ async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
        raise Exception
    except:
        raise Exception

+async def generate_key_fn(data: GenerateKeyRequest):
+    """
+    Asynchronously decides if a key should be generated or not based on the provided data.
+
+    Args:
+        data (GenerateKeyRequest): The data to be used for decision making.
+
+    Returns:
+        bool: True if a key should be generated, False otherwise.
+    """
+    # decide if a key should be generated or not
+    data_json = data.json()  # type: ignore
+
+    # Unpacking variables
+    team_id = data_json.get("team_id")
+    duration = data_json.get("duration")
+    models = data_json.get("models")
+    aliases = data_json.get("aliases")
+    config = data_json.get("config")
+    spend = data_json.get("spend")
+    user_id = data_json.get("user_id")
+    max_parallel_requests = data_json.get("max_parallel_requests")
+    metadata = data_json.get("metadata")
+    tpm_limit = data_json.get("tpm_limit")
+    rpm_limit = data_json.get("rpm_limit")
+
+    if team_id is not None and len(team_id) > 0:
+        return {
+            "decision": True,
+        }
+    else:
+        return {
+            "decision": True,
+            "message": "This violates LiteLLM Proxy Rules. No team id provided.",
+        }


@@ -62,8 +62,9 @@ litellm_settings:
  # setting callback class
  # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]

-# general_settings:
-#   master_key: sk-1234
+general_settings:
+  master_key: sk-1234
+  custom_key_generate: custom_auth.generate_key_fn
  # database_type: "dynamo_db"
  # database_args: { # 👈 all args - https://github.com/BerriAI/litellm/blob/befbcbb7ac8f59835ce47415c128decf37aac328/litellm/proxy/_types.py#L190
  #   "billing_mode": "PAY_PER_REQUEST",


@@ -187,6 +187,7 @@ prisma_client: Optional[PrismaClient] = None
custom_db_client: Optional[DBClient] = None
user_api_key_cache = DualCache()
user_custom_auth = None
+user_custom_key_generate = None
use_background_health_checks = None
use_queue = False
health_check_interval = None

@@ -584,7 +585,7 @@ async def track_cost_callback(
                    "user_api_key_user_id", None
                )
-               verbose_proxy_logger.debug(
+               verbose_proxy_logger.info(
                    f"streaming response_cost {response_cost}, for user_id {user_id}"
                )
                if user_api_key and (

@@ -609,7 +610,7 @@ async def track_cost_callback(
                user_id = user_id or kwargs["litellm_params"]["metadata"].get(
                    "user_api_key_user_id", None
                )
-               verbose_proxy_logger.debug(
+               verbose_proxy_logger.info(
                    f"response_cost {response_cost}, for user_id {user_id}"
                )
                if user_api_key and (

@@ -896,7 +897,7 @@ class ProxyConfig:
        """
        Load config values into proxy global state
        """
-       global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue, custom_db_client
+       global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client
        # Load existing config
        config = await self.get_config(config_file_path=config_file_path)

@@ -1074,6 +1075,12 @@ class ProxyConfig:
            user_custom_auth = get_instance_fn(
                value=custom_auth, config_file_path=config_file_path
            )
+           custom_key_generate = general_settings.get("custom_key_generate", None)
+           if custom_key_generate is not None:
+               user_custom_key_generate = get_instance_fn(
+                   value=custom_key_generate, config_file_path=config_file_path
+               )
            ## dynamodb
            database_type = general_settings.get("database_type", None)
            if database_type is not None and (

@@ -2189,7 +2196,16 @@ async def generate_key_fn(
    - expires: (datetime) Datetime object for when key expires.
    - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
    """
+   global user_custom_key_generate
    verbose_proxy_logger.debug("entered /key/generate")
+   if user_custom_key_generate is not None:
+       result = await user_custom_key_generate(data)
+       decision = result.get("decision", True)
+       message = result.get("message", "Authentication Failed - Custom Auth Rule")
+       if not decision:
+           raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=message)
    data_json = data.json()  # type: ignore
    # if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users

@@ -2978,7 +2994,7 @@ async def get_routes():
@router.on_event("shutdown")
async def shutdown_event():
-   global prisma_client, master_key, user_custom_auth
+   global prisma_client, master_key, user_custom_auth, user_custom_key_generate
    if prisma_client:
        verbose_proxy_logger.debug("Disconnecting from Prisma")
        await prisma_client.disconnect()

@@ -2988,7 +3004,7 @@ async def shutdown_event():
def cleanup_router_config_variables():
-   global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, prisma_client, custom_db_client
+   global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, prisma_client, custom_db_client
    # Set all variables to None
    master_key = None

@@ -2996,6 +3012,7 @@ def cleanup_router_config_variables():
    otel_logging = None
    user_custom_auth = None
    user_custom_auth_path = None
+   user_custom_key_generate = None
    use_background_health_checks = None
    health_check_interval = None
    prisma_client = None


@@ -449,6 +449,7 @@ class PrismaClient:
                        "update": {},  # don't do anything if it already exists
                    },
                )
+               verbose_proxy_logger.info(f"Data Inserted into Keys Table")
                return new_verification_token
            elif table_name == "user":
                db_data = self.jsonify_object(data=data)

@@ -459,6 +460,7 @@ class PrismaClient:
                        "update": {},  # don't do anything if it already exists
                    },
                )
+               verbose_proxy_logger.info(f"Data Inserted into User Table")
                return new_user_row
            elif table_name == "config":
                """

@@ -483,6 +485,7 @@ class PrismaClient:
                    tasks.append(updated_table_row)
                await asyncio.gather(*tasks)
+               verbose_proxy_logger.info(f"Data Inserted into Config Table")
            elif table_name == "spend":
                db_data = self.jsonify_object(data=data)
                new_spend_row = await self.db.litellm_spendlogs.upsert(

@@ -492,6 +495,7 @@ class PrismaClient:
                        "update": {},  # don't do anything if it already exists
                    },
                )
+               verbose_proxy_logger.info(f"Data Inserted into Spend Table")
                return new_spend_row
        except Exception as e:


@@ -997,6 +997,9 @@ class Router:
                        """
                        try:
                            kwargs["model"] = mg
+                           kwargs.setdefault("metadata", {}).update(
+                               {"model_group": mg}
+                           )  # update model_group used, if fallbacks are done
                            response = await self.async_function_with_retries(
                                *args, **kwargs
                            )

@@ -1025,8 +1028,10 @@ class Router:
                                f"Falling back to model_group = {mg}"
                            )
                            kwargs["model"] = mg
-                           kwargs["metadata"]["model_group"] = mg
-                           response = await self.async_function_with_retries(
+                           kwargs.setdefault("metadata", {}).update(
+                               {"model_group": mg}
+                           )  # update model_group used, if fallbacks are done
+                           response = await self.async_function_with_fallbacks(
                                *args, **kwargs
                            )
                            return response

@@ -1191,6 +1196,9 @@ class Router:
                        ## LOGGING
                        kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
                        kwargs["model"] = mg
+                       kwargs.setdefault("metadata", {}).update(
+                           {"model_group": mg}
+                       )  # update model_group used, if fallbacks are done
                        response = self.function_with_fallbacks(*args, **kwargs)
                        return response
                    except Exception as e:

@@ -1214,6 +1222,9 @@ class Router:
                        ## LOGGING
                        kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
                        kwargs["model"] = mg
+                       kwargs.setdefault("metadata", {}).update(
+                           {"model_group": mg}
+                       )  # update model_group used, if fallbacks are done
                        response = self.function_with_fallbacks(*args, **kwargs)
                        return response
                    except Exception as e:


@@ -1372,16 +1372,21 @@ def test_customprompt_together_ai():

def test_completion_sagemaker():
    try:
-       print("testing sagemaker")
        litellm.set_verbose = True
+       print("testing sagemaker")
        response = completion(
            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
            messages=messages,
            temperature=0.2,
            max_tokens=80,
+           input_cost_per_second=0.000420,
        )
        # Add any assertions here to check the response
        print(response)
+       cost = completion_cost(completion_response=response)
+       assert (
+           cost > 0.0 and cost < 1.0
+       )  # should never be > $1 for a single completion call
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


@@ -1,56 +1,58 @@
### What this tests ####
import sys, os, time, inspect, asyncio, traceback
import pytest

sys.path.insert(0, os.path.abspath("../.."))

from litellm import completion, embedding
import litellm
from litellm.integrations.custom_logger import CustomLogger


class MyCustomHandler(CustomLogger):
    complete_streaming_response_in_callback = ""

    def __init__(self):
        self.success: bool = False  # type: ignore
        self.failure: bool = False  # type: ignore
        self.async_success: bool = False  # type: ignore
        self.async_success_embedding: bool = False  # type: ignore
        self.async_failure: bool = False  # type: ignore
        self.async_failure_embedding: bool = False  # type: ignore
        self.async_completion_kwargs = None  # type: ignore
        self.async_embedding_kwargs = None  # type: ignore
        self.async_embedding_response = None  # type: ignore
        self.async_completion_kwargs_fail = None  # type: ignore
        self.async_embedding_kwargs_fail = None  # type: ignore
        self.stream_collected_response = None  # type: ignore
        self.sync_stream_collected_response = None  # type: ignore
        self.user = None  # type: ignore
        self.data_sent_to_api: dict = {}

    def log_pre_api_call(self, model, messages, kwargs):
        print(f"Pre-API Call")
        self.data_sent_to_api = kwargs["additional_args"].get("complete_input_dict", {})

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        print(f"Post-API Call")

    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Stream")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Success")
        self.success = True
        if kwargs.get("stream") == True:
            self.sync_stream_collected_response = response_obj

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Failure")
        self.failure = True

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Async success")
        print(f"received kwargs user: {kwargs['user']}")
        self.async_success = True
@@ -62,24 +64,30 @@ class MyCustomHandler(CustomLogger):
            self.stream_collected_response = response_obj
        self.async_completion_kwargs = kwargs
        self.user = kwargs.get("user", None)

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Async Failure")
        self.async_failure = True
        if kwargs.get("model") == "text-embedding-ada-002":
            self.async_failure_embedding = True
            self.async_embedding_kwargs_fail = kwargs
        self.async_completion_kwargs_fail = kwargs


class TmpFunction:
    complete_streaming_response_in_callback = ""
    async_success: bool = False

    async def async_test_logging_fn(self, kwargs, completion_obj, start_time, end_time):
        print(f"ON ASYNC LOGGING")
        self.async_success = True
        print(
            f'kwargs.get("complete_streaming_response"): {kwargs.get("complete_streaming_response")}'
        )
        self.complete_streaming_response_in_callback = kwargs.get(
            "complete_streaming_response"
        )


def test_async_chat_openai_stream():
@@ -88,29 +96,39 @@ def test_async_chat_openai_stream():
        # litellm.set_verbose = True
        litellm.success_callback = [tmp_function.async_test_logging_fn]
        complete_streaming_response = ""

        async def call_gpt():
            nonlocal complete_streaming_response
            response = await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
                stream=True,
            )
            async for chunk in response:
                complete_streaming_response += (
                    chunk["choices"][0]["delta"]["content"] or ""
                )
                print(complete_streaming_response)

        asyncio.run(call_gpt())
        complete_streaming_response = complete_streaming_response.strip("'")
        response1 = tmp_function.complete_streaming_response_in_callback["choices"][0][
            "message"
        ]["content"]
        response2 = complete_streaming_response
        # assert [ord(c) for c in response1] == [ord(c) for c in response2]
        print(f"response1: {response1}")
        print(f"response2: {response2}")
        assert response1 == response2
        assert tmp_function.async_success == True
    except Exception as e:
        print(e)
        pytest.fail(f"An error occurred - {str(e)}")


# test_async_chat_openai_stream()


def test_completion_azure_stream_moderation_failure():
    try:
        customHandler = MyCustomHandler()
@@ -122,11 +140,11 @@ def test_completion_azure_stream_moderation_failure():
                "content": "how do i kill someone",
            },
        ]
        try:
            response = completion(
                model="azure/chatgpt-v-2", messages=messages, stream=True
            )
            for chunk in response:
                print(f"chunk: {chunk}")
                continue
        except Exception as e:
@@ -139,7 +157,7 @@ def test_completion_azure_stream_moderation_failure():

def test_async_custom_handler_stream():
    try:
        # [PROD Test] - Do not DELETE
        # checks if the model response available in the async + stream callbacks is equal to the received response
        customHandler2 = MyCustomHandler()
        litellm.callbacks = [customHandler2]
@@ -152,32 +170,37 @@ def test_async_custom_handler_stream():
            },
        ]
        complete_streaming_response = ""

        async def test_1():
            nonlocal complete_streaming_response
            response = await litellm.acompletion(
                model="azure/chatgpt-v-2", messages=messages, stream=True
            )
            async for chunk in response:
                complete_streaming_response += (
                    chunk["choices"][0]["delta"]["content"] or ""
                )
                print(complete_streaming_response)

        asyncio.run(test_1())

        response_in_success_handler = customHandler2.stream_collected_response
        response_in_success_handler = response_in_success_handler["choices"][0][
            "message"
        ]["content"]
        print("\n\n")
        print("response_in_success_handler: ", response_in_success_handler)
        print("complete_streaming_response: ", complete_streaming_response)
        assert response_in_success_handler == complete_streaming_response
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_async_custom_handler_stream()


def test_azure_completion_stream():
    # [PROD Test] - Do not DELETE
    # test if completion() + sync custom logger get the same complete stream response
    try:
        # checks if the model response available in the async + stream callbacks is equal to the received response
@@ -194,17 +217,17 @@ def test_azure_completion_stream():
        complete_streaming_response = ""
        response = litellm.completion(
            model="azure/chatgpt-v-2", messages=messages, stream=True
        )
        for chunk in response:
            complete_streaming_response += chunk["choices"][0]["delta"]["content"] or ""
            print(complete_streaming_response)

        time.sleep(0.5)  # wait 1/2 second before checking callbacks
        response_in_success_handler = customHandler2.sync_stream_collected_response
        response_in_success_handler = response_in_success_handler["choices"][0][
            "message"
        ]["content"]
        print("\n\n")
        print("response_in_success_handler: ", response_in_success_handler)
        print("complete_streaming_response: ", complete_streaming_response)
@@ -212,24 +235,32 @@ def test_azure_completion_stream():
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.asyncio
async def test_async_custom_handler_completion():
    try:
        customHandler_success = MyCustomHandler()
        customHandler_failure = MyCustomHandler()
        # success
        assert customHandler_success.async_success == False
        litellm.callbacks = [customHandler_success]
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[
                {
                    "role": "user",
                    "content": "hello from litellm test",
                }
            ],
        )
        await asyncio.sleep(1)
        assert (
            customHandler_success.async_success == True
        ), "async success is not set to True even after success"
        assert (
            customHandler_success.async_completion_kwargs.get("model")
            == "gpt-3.5-turbo"
        )
        # failure
        litellm.callbacks = [customHandler_failure]
        messages = [
@@ -240,80 +271,119 @@ async def test_async_custom_handler_completion():
            },
        ]
        assert customHandler_failure.async_failure == False
        try:
            response = await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=messages,
                api_key="my-bad-key",
            )
        except:
            pass
        assert (
            customHandler_failure.async_failure == True
        ), "async failure is not set to True even after failure"
        assert (
            customHandler_failure.async_completion_kwargs_fail.get("model")
            == "gpt-3.5-turbo"
        )
        assert (
            len(
                str(customHandler_failure.async_completion_kwargs_fail.get("exception"))
            )
            > 10
        )  # expect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
        litellm.callbacks = []
        print("Passed setting async failure")
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")


# asyncio.run(test_async_custom_handler_completion())


@pytest.mark.asyncio
async def test_async_custom_handler_embedding():
    try:
        customHandler_embedding = MyCustomHandler()
        litellm.callbacks = [customHandler_embedding]
        # success
        assert customHandler_embedding.async_success_embedding == False
        response = await litellm.aembedding(
            model="text-embedding-ada-002",
            input=["hello world"],
        )
        await asyncio.sleep(1)
        assert (
            customHandler_embedding.async_success_embedding == True
        ), "async_success_embedding is not set to True even after success"
        assert (
            customHandler_embedding.async_embedding_kwargs.get("model")
            == "text-embedding-ada-002"
        )
        assert (
            customHandler_embedding.async_embedding_response["usage"]["prompt_tokens"]
            == 2
        )
        print("Passed setting async success: Embedding")
        # failure
        assert customHandler_embedding.async_failure_embedding == False
        try:
            response = await litellm.aembedding(
                model="text-embedding-ada-002",
                input=["hello world"],
                api_key="my-bad-key",
            )
        except:
            pass
        assert (
            customHandler_embedding.async_failure_embedding == True
        ), "async failure embedding is not set to True even after failure"
        assert (
            customHandler_embedding.async_embedding_kwargs_fail.get("model")
            == "text-embedding-ada-002"
        )
        assert (
            len(
                str(
                    customHandler_embedding.async_embedding_kwargs_fail.get("exception")
                )
            )
            > 10
        )  # exppect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")


# asyncio.run(test_async_custom_handler_embedding())


@pytest.mark.asyncio
async def test_async_custom_handler_embedding_optional_param():
    """
    Tests if the openai optional params for embedding - user + encoding_format,
    are logged
    """
    customHandler_optional_params = MyCustomHandler()
    litellm.callbacks = [customHandler_optional_params]
    response = await litellm.aembedding(
        model="azure/azure-embedding-model", input=["hello world"], user="John"
    )
    await asyncio.sleep(1)  # success callback is async
    assert customHandler_optional_params.user == "John"
    assert (
        customHandler_optional_params.user
        == customHandler_optional_params.data_sent_to_api["user"]
    )


# asyncio.run(test_async_custom_handler_embedding_optional_param())


@pytest.mark.asyncio
async def test_async_custom_handler_embedding_optional_param_bedrock():
    """
    Tests if the openai optional params for embedding - user + encoding_format,
    are logged
    but makes sure these are not sent to the non-openai/azure endpoint (raises errors).
@@ -323,42 +393,68 @@ async def test_async_custom_handler_embedding_optional_param_bedrock():
    customHandler_optional_params = MyCustomHandler()
    litellm.callbacks = [customHandler_optional_params]
    response = await litellm.aembedding(
        model="bedrock/amazon.titan-embed-text-v1", input=["hello world"], user="John"
    )
    await asyncio.sleep(1)  # success callback is async
    assert customHandler_optional_params.user == "John"
    assert "user" not in customHandler_optional_params.data_sent_to_api


def test_redis_cache_completion_stream():
    from litellm import Cache

    # Important Test - This tests if we can add to streaming cache, when custom callbacks are set
    import random

    try:
        print("\nrunning test_redis_cache_completion_stream")
        litellm.set_verbose = True
        random_number = random.randint(
            1, 100000
        )  # add a random number to ensure it's always adding / reading from cache
        messages = [
            {
                "role": "user",
                "content": f"write a one sentence poem about: {random_number}",
            }
        ]
        litellm.cache = Cache(
            type="redis",
            host=os.environ["REDIS_HOST"],
            port=os.environ["REDIS_PORT"],
            password=os.environ["REDIS_PASSWORD"],
        )
        print("test for caching, streaming + completion")
        response1 = completion(
            model="gpt-3.5-turbo",
            messages=messages,
            max_tokens=40,
            temperature=0.2,
            stream=True,
        )
        response_1_content = ""
        for chunk in response1:
            print(chunk)
            response_1_content += chunk.choices[0].delta.content or ""
        print(response_1_content)
        time.sleep(0.1)  # sleep for 0.1 seconds allow set cache to occur
        response2 = completion(
            model="gpt-3.5-turbo",
            messages=messages,
            max_tokens=40,
            temperature=0.2,
            stream=True,
        )
        response_2_content = ""
        for chunk in response2:
            print(chunk)
            response_2_content += chunk.choices[0].delta.content or ""
        print("\nresponse 1", response_1_content)
        print("\nresponse 2", response_2_content)
        assert (
            response_1_content == response_2_content
        ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
        litellm.success_callback = []
        litellm._async_success_callback = []
        litellm.cache = None
@@ -366,4 +462,6 @@ def test_redis_cache_completion_stream():
        print(e)
        litellm.success_callback = []
        raise e


# test_redis_cache_completion_stream()


@@ -10,7 +10,7 @@ sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
-from litellm import embedding, completion
+from litellm import embedding, completion, completion_cost

litellm.set_verbose = False

@@ -341,8 +341,30 @@ def test_sagemaker_embeddings():
        response = litellm.embedding(
            model="sagemaker/berri-benchmarking-gpt-j-6b-fp16",
            input=["good morning from litellm", "this is another item"],
+           input_cost_per_second=0.000420,
        )
        print(f"response: {response}")
+       cost = completion_cost(completion_response=response)
+       assert (
+           cost > 0.0 and cost < 1.0
+       )  # should never be > $1 for a single embedding call
+   except Exception as e:
+       pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio
+async def test_sagemaker_aembeddings():
+   try:
+       response = await litellm.aembedding(
+           model="sagemaker/berri-benchmarking-gpt-j-6b-fp16",
+           input=["good morning from litellm", "this is another item"],
+           input_cost_per_second=0.000420,
+       )
+       print(f"response: {response}")
+       cost = completion_cost(completion_response=response)
+       assert (
+           cost > 0.0 and cost < 1.0
+       )  # should never be > $1 for a single embedding call
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


@ -35,6 +35,7 @@ import pytest, logging, asyncio
import litellm, asyncio import litellm, asyncio
from litellm.proxy.proxy_server import ( from litellm.proxy.proxy_server import (
new_user, new_user,
generate_key_fn,
user_api_key_auth, user_api_key_auth,
user_update, user_update,
delete_key_fn, delete_key_fn,
@ -53,6 +54,7 @@ from litellm.proxy._types import (
DynamoDBArgs, DynamoDBArgs,
DeleteKeyRequest, DeleteKeyRequest,
UpdateKeyRequest, UpdateKeyRequest,
GenerateKeyRequest,
) )
from litellm.proxy.utils import DBClient from litellm.proxy.utils import DBClient
from starlette.datastructures import URL from starlette.datastructures import URL
@ -597,6 +599,85 @@ def test_generate_and_update_key(prisma_client):
print("Got Exception", e) print("Got Exception", e)
print(e.detail) print(e.detail)
pytest.fail(f"An exception occurred - {str(e)}") pytest.fail(f"An exception occurred - {str(e)}")


def test_key_generate_with_custom_auth(prisma_client):
    # custom - generate key function
    async def custom_generate_key_fn(data: GenerateKeyRequest) -> dict:
        """
        Asynchronous function for generating a key based on the input data.

        Args:
            data (GenerateKeyRequest): The input data for key generation.

        Returns:
            dict: A dictionary containing the decision and an optional message.
            {
                "decision": False,
                "message": "This violates LiteLLM Proxy Rules. No team id provided.",
            }
        """
        # decide if a key should be generated or not
        print("using custom auth function!")
        data_json = data.json()  # type: ignore

        # Unpacking variables
        team_id = data_json.get("team_id")
        duration = data_json.get("duration")
        models = data_json.get("models")
        aliases = data_json.get("aliases")
        config = data_json.get("config")
        spend = data_json.get("spend")
        user_id = data_json.get("user_id")
        max_parallel_requests = data_json.get("max_parallel_requests")
        metadata = data_json.get("metadata")
        tpm_limit = data_json.get("tpm_limit")
        rpm_limit = data_json.get("rpm_limit")

        if team_id is not None and team_id == "litellm-core-infra@gmail.com":
            # only team_id="litellm-core-infra@gmail.com" can make keys
            return {
                "decision": True,
            }
        else:
            print("Failed custom auth")
            return {
                "decision": False,
                "message": "This violates LiteLLM Proxy Rules. No team id provided.",
            }

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(
        litellm.proxy.proxy_server, "user_custom_key_generate", custom_generate_key_fn
    )
    try:

        async def test():
            try:
                request = GenerateKeyRequest()
                key = await generate_key_fn(request)
                pytest.fail(f"Expected an exception. Got {key}")
            except Exception as e:
                # this should fail
                print("Got Exception", e)
                print(e.detail)
                print("First request failed!. This is expected")
                assert (
                    "This violates LiteLLM Proxy Rules. No team id provided."
                    in e.detail
                )

            request_2 = GenerateKeyRequest(
                team_id="litellm-core-infra@gmail.com",
            )
            key = await generate_key_fn(request_2)
            print(key)
            generated_key = key.key

        asyncio.run(test())
    except Exception as e:
        print("Got Exception", e)
        print(e.detail)
        pytest.fail(f"An exception occurred - {str(e)}")


def test_call_with_key_over_budget(prisma_client):
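The test above drives async proxy helpers from a synchronous pytest function by wrapping the awaited calls in an inner coroutine and handing it to `asyncio.run`. A minimal, self-contained sketch of that pattern (the names below are illustrative, not from the repo):

```python
import asyncio


async def _call_async_helper() -> str:
    # stand-in for awaited proxy calls such as generate_key_fn(...)
    await asyncio.sleep(0)
    return "ok"


def test_sync_wrapper_pattern():
    # synchronous test body; the event loop exists only inside asyncio.run
    assert asyncio.run(_call_async_helper()) == "ok"
```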

View file

@ -716,7 +716,7 @@ def test_usage_based_routing_fallbacks():
        # Constants for TPM and RPM allocation
        AZURE_FAST_TPM = 3
        AZURE_BASIC_TPM = 4
        OPENAI_TPM = 400
        ANTHROPIC_TPM = 100000

        def get_azure_params(deployment_name: str):
@ -775,6 +775,7 @@ def test_usage_based_routing_fallbacks():
            model_list=model_list,
            fallbacks=fallbacks_list,
            set_verbose=True,
            debug_level="DEBUG",
            routing_strategy="usage-based-routing",
            redis_host=os.environ["REDIS_HOST"],
            redis_port=os.environ["REDIS_PORT"],
@ -783,17 +784,32 @@ def test_usage_based_routing_fallbacks():
        messages = [
            {"content": "Tell me a joke.", "role": "user"},
        ]
        response = router.completion(
            model="azure/gpt-4-fast",
            messages=messages,
            timeout=5,
            mock_response="very nice to meet you",
        )
        print("response: ", response)
        print("response._hidden_params: ", response._hidden_params)
        # in this test, we expect azure/gpt-4 fast to fail, then azure-gpt-4 basic to fail and then openai-gpt-4 to pass
        # the token count of this message is > AZURE_FAST_TPM, > AZURE_BASIC_TPM
        assert response._hidden_params["custom_llm_provider"] == "openai"

        # now make 20 mock requests to OpenAI - expect it to fall back to anthropic-claude-instant-1.2
        for i in range(20):
            response = router.completion(
                model="azure/gpt-4-fast",
                messages=messages,
                timeout=5,
                mock_response="very nice to meet you",
            )
            print("response: ", response)
            print("response._hidden_params: ", response._hidden_params)
            if i == 19:
                # by the 19th call we should have hit the TPM limit for OpenAI, so it should fall back to anthropic-claude-instant-1.2
                assert response._hidden_params["custom_llm_provider"] == "anthropic"
    except Exception as e:
        pytest.fail(f"An exception occurred {e}")

View file

@ -765,6 +765,7 @@ class Logging:
        self.litellm_call_id = litellm_call_id
        self.function_id = function_id
        self.streaming_chunks = []  # for generating complete stream response
        self.sync_streaming_chunks = []  # for generating complete stream response
        self.model_call_details = {}

    def update_environment_variables(
@ -828,7 +829,7 @@ class Logging:
[f"-H '{k}: {v}'" for k, v in masked_headers.items()] [f"-H '{k}: {v}'" for k, v in masked_headers.items()]
) )
print_verbose(f"PRE-API-CALL ADDITIONAL ARGS: {additional_args}") verbose_logger.debug(f"PRE-API-CALL ADDITIONAL ARGS: {additional_args}")
curl_command = "\n\nPOST Request Sent from LiteLLM:\n" curl_command = "\n\nPOST Request Sent from LiteLLM:\n"
curl_command += "curl -X POST \\\n" curl_command += "curl -X POST \\\n"
@ -994,13 +995,10 @@ class Logging:
self.model_call_details["log_event_type"] = "post_api_call" self.model_call_details["log_event_type"] = "post_api_call"
# User Logging -> if you pass in a custom logging function # User Logging -> if you pass in a custom logging function
print_verbose( verbose_logger.debug(
f"RAW RESPONSE:\n{self.model_call_details.get('original_response', self.model_call_details)}\n\n" f"RAW RESPONSE:\n{self.model_call_details.get('original_response', self.model_call_details)}\n\n"
) )
print_verbose( verbose_logger.debug(
f"Logging Details Post-API Call: logger_fn - {self.logger_fn} | callable(logger_fn) - {callable(self.logger_fn)}"
)
print_verbose(
f"Logging Details Post-API Call: LiteLLM Params: {self.model_call_details}" f"Logging Details Post-API Call: LiteLLM Params: {self.model_call_details}"
) )
if self.logger_fn and callable(self.logger_fn): if self.logger_fn and callable(self.logger_fn):
@ -1094,20 +1092,20 @@ class Logging:
            if (
                result.choices[0].finish_reason is not None
            ):  # if it's the last chunk
                self.sync_streaming_chunks.append(result)
                # print_verbose(f"final set of received chunks: {self.sync_streaming_chunks}")
                try:
                    complete_streaming_response = litellm.stream_chunk_builder(
                        self.sync_streaming_chunks,
                        messages=self.model_call_details.get("messages", None),
                    )
                except:
                    complete_streaming_response = None
            else:
                self.sync_streaming_chunks.append(result)

        if complete_streaming_response:
            verbose_logger.debug(
                f"Logging Details LiteLLM-Success Call streaming complete"
            )
            self.model_call_details[
@ -1307,7 +1305,9 @@ class Logging:
                    )
                    == False
                ):  # custom logger class
                    verbose_logger.info(
                        f"success callbacks: Running SYNC Custom Logger Class"
                    )
                    if self.stream and complete_streaming_response is None:
                        callback.log_stream_event(
                            kwargs=self.model_call_details,
@ -1329,7 +1329,17 @@ class Logging:
                            start_time=start_time,
                            end_time=end_time,
                        )
                elif (
                    callable(callback) == True
                    and self.model_call_details.get("litellm_params", {}).get(
                        "acompletion", False
                    )
                    == False
                    and self.model_call_details.get("litellm_params", {}).get(
                        "aembedding", False
                    )
                    == False
                ):  # custom logger functions
                    print_verbose(
                        f"success callbacks: Running Custom Callback Function"
                    )
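The `elif` above only routes plain callback functions for synchronous calls; async completions and embeddings are handled by the async success path further down. A hedged sketch of registering such a function, assuming litellm's documented custom-callback signature of `(kwargs, completion_response, start_time, end_time)`:

```python
import litellm


def log_latency(kwargs, completion_response, start_time, end_time):
    # runs after each successful sync call; kwargs carries the litellm params
    latency_ms = (end_time - start_time).total_seconds() * 1000
    print(f"model={kwargs.get('model')} latency_ms={latency_ms:.1f}")


litellm.success_callback = [log_latency]
```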
@ -1364,6 +1374,9 @@ class Logging:
        Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
        """
        print_verbose(f"Async success callbacks: {litellm._async_success_callback}")
        start_time, end_time, result = self._success_handler_helper_fn(
            start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
        )
        ## BUILD COMPLETE STREAMED RESPONSE
        complete_streaming_response = None
        if self.stream:
@ -1374,6 +1387,8 @@ class Logging:
                complete_streaming_response = litellm.stream_chunk_builder(
                    self.streaming_chunks,
                    messages=self.model_call_details.get("messages", None),
                    start_time=start_time,
                    end_time=end_time,
                )
            except Exception as e:
                print_verbose(
@ -1387,9 +1402,7 @@ class Logging:
            self.model_call_details[
                "complete_streaming_response"
            ] = complete_streaming_response

        for callback in litellm._async_success_callback:
            try:
                if callback == "cache" and litellm.cache is not None:
@ -1436,7 +1449,6 @@ class Logging:
                        end_time=end_time,
                    )
                if callable(callback):  # custom logger functions
                    await customLogger.async_log_event(
                        kwargs=self.model_call_details,
                        response_obj=result,
@ -2134,7 +2146,7 @@ def client(original_function):
            litellm.cache.add_cache(result, *args, **kwargs)
            # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
            verbose_logger.info(f"Wrapper: Completed Call, calling success_handler")
            threading.Thread(
                target=logging_obj.success_handler, args=(result, start_time, end_time)
            ).start()
@ -2373,7 +2385,9 @@ def client(original_function):
result._hidden_params["model_id"] = kwargs.get("model_info", {}).get( result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
"id", None "id", None
) )
if isinstance(result, ModelResponse): if isinstance(result, ModelResponse) or isinstance(
result, EmbeddingResponse
):
result._response_ms = ( result._response_ms = (
end_time - start_time end_time - start_time
).total_seconds() * 1000 # return response latency in ms like openai ).total_seconds() * 1000 # return response latency in ms like openai
@ -2806,7 +2820,11 @@ def token_counter(
def cost_per_token(
    model="",
    prompt_tokens=0,
    completion_tokens=0,
    response_time_ms=None,
    custom_llm_provider=None,
):
    """
    Calculates the cost per token for a given model, prompt tokens, and completion tokens.
@ -2828,14 +2846,35 @@ def cost_per_token(
    else:
        model_with_provider = model
    # see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
    verbose_logger.debug(f"Looking up model={model} in model_cost_map")
    if model in model_cost_ref:
        verbose_logger.debug(f"Success: model={model} in model_cost_map")
        if (
            model_cost_ref[model].get("input_cost_per_token", None) is not None
            and model_cost_ref[model].get("output_cost_per_token", None) is not None
        ):
            ## COST PER TOKEN ##
            prompt_tokens_cost_usd_dollar = (
                model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
            )
            completion_tokens_cost_usd_dollar = (
                model_cost_ref[model]["output_cost_per_token"] * completion_tokens
            )
        elif (
            model_cost_ref[model].get("input_cost_per_second", None) is not None
            and response_time_ms is not None
        ):
            verbose_logger.debug(
                f"For model={model} - input_cost_per_second: {model_cost_ref[model].get('input_cost_per_second')}; response time: {response_time_ms}"
            )
            ## COST PER SECOND ##
            prompt_tokens_cost_usd_dollar = (
                model_cost_ref[model]["input_cost_per_second"] * response_time_ms / 1000
            )
            completion_tokens_cost_usd_dollar = 0.0
        verbose_logger.debug(
            f"Returned custom cost for model={model} - prompt_tokens_cost_usd_dollar: {prompt_tokens_cost_usd_dollar}, completion_tokens_cost_usd_dollar: {completion_tokens_cost_usd_dollar}"
        )
        return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
    elif model_with_provider in model_cost_ref:
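With the branch above in place, per-second pricing can be exercised through `cost_per_token` directly, provided the model has an `input_cost_per_second` entry in the cost map. A hedged sketch; the model name and latency below are illustrative, not real entries:

```python
import litellm

# assumed custom entry; any model name you register yourself works the same way
litellm.register_model(
    {
        "my-sagemaker-endpoint": {
            "input_cost_per_second": 0.000420,
            "litellm_provider": "sagemaker",
        }
    }
)

prompt_cost, completion_cost_usd = litellm.cost_per_token(
    model="my-sagemaker-endpoint",
    prompt_tokens=50,
    completion_tokens=0,
    response_time_ms=2500,  # assumed 2.5 s call
)
print(prompt_cost, completion_cost_usd)  # ~0.00105, 0.0
```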
@ -2938,6 +2977,10 @@ def completion_cost(
            completion_tokens = completion_response.get("usage", {}).get(
                "completion_tokens", 0
            )
            total_time = completion_response.get("_response_ms", 0)
            verbose_logger.debug(
                f"completion_response response ms: {completion_response.get('_response_ms')} "
            )
            model = (
                model or completion_response["model"]
            )  # check if user passed an override for model, if it's none check completion_response['model']
@ -2975,6 +3018,7 @@ def completion_cost(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            custom_llm_provider=custom_llm_provider,
            response_time_ms=total_time,
        )
        return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
    except Exception as e:
@ -3005,9 +3049,8 @@ def register_model(model_cost: Union[str, dict]):
    for key, value in loaded_model_cost.items():
        ## override / add new keys to the existing model cost dictionary
        litellm.model_cost.setdefault(key, {}).update(value)
        verbose_logger.debug(f"{key} added to model cost map")
        # add new model names to provider lists
        if value.get("litellm_provider") == "openai":
            if key not in litellm.open_ai_chat_completion_models:
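The switch to `setdefault(...).update(...)` means registration now both merges fields into existing cost-map entries and creates brand-new ones, where the old loop only touched keys that were already present. A tiny sketch of just that dict behavior, with placeholder model names and prices:

```python
# Dict semantics behind the change, using made-up entries.
model_cost = {"existing-model": {"input_cost_per_token": 1e-6}}
new_entries = {"new-model": {"input_cost_per_second": 0.00042}}

for key, value in new_entries.items():
    model_cost.setdefault(key, {}).update(value)

print(model_cost["new-model"])  # {'input_cost_per_second': 0.00042}
```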
@ -3300,11 +3343,13 @@ def get_optional_params(
    )

    def _check_valid_arg(supported_params):
        verbose_logger.debug(
            f"\nLiteLLM completion() model= {model}; provider = {custom_llm_provider}"
        )
        verbose_logger.debug(
            f"\nLiteLLM: Params passed to completion() {passed_params}"
        )
        verbose_logger.debug(
            f"\nLiteLLM: Non-Default params passed to completion() {non_default_params}"
        )
        unsupported_params = {}
@ -5150,6 +5195,8 @@ def convert_to_model_response_object(
"completion", "embedding", "image_generation" "completion", "embedding", "image_generation"
] = "completion", ] = "completion",
stream=False, stream=False,
start_time=None,
end_time=None,
): ):
try: try:
if response_type == "completion" and ( if response_type == "completion" and (
@ -5203,6 +5250,12 @@ def convert_to_model_response_object(
if "model" in response_object: if "model" in response_object:
model_response_object.model = response_object["model"] model_response_object.model = response_object["model"]
if start_time is not None and end_time is not None:
model_response_object._response_ms = ( # type: ignore
end_time - start_time
).total_seconds() * 1000
return model_response_object return model_response_object
elif response_type == "embedding" and ( elif response_type == "embedding" and (
model_response_object is None model_response_object is None
@ -5227,6 +5280,11 @@ def convert_to_model_response_object(
            model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0)  # type: ignore
            model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0)  # type: ignore

            if start_time is not None and end_time is not None:
                model_response_object._response_ms = (  # type: ignore
                    end_time - start_time
                ).total_seconds() * 1000  # return response latency in ms like openai

            return model_response_object
        elif response_type == "image_generation" and (
            model_response_object is None
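The `_response_ms` attached above is what `completion_cost` later reads as `total_time` when applying per-second pricing. A small sketch of the latency arithmetic, with assumed timestamps:

```python
from datetime import datetime, timedelta

start_time = datetime(2024, 1, 22, 12, 0, 0)
end_time = start_time + timedelta(seconds=2, milliseconds=500)  # assumed 2.5 s call

response_ms = (end_time - start_time).total_seconds() * 1000
print(response_ms)  # 2500.0, later consumed as response_time_ms by cost_per_token
```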

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "1.18.9"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"
@ -61,7 +61,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"

[tool.commitizen]
version = "1.18.9"
version_files = [
    "pyproject.toml:^version"
]