Compare commits

..

5 commits

Author SHA1 Message Date
Ishaan Jaff
c94f9f3b1e test_audio_speech_router 2024-11-28 20:21:26 -08:00
Ishaan Jaff
e2787eeefe
Merge branch 'main' into litellm_fix_router_aspeech 2024-11-28 20:18:52 -08:00
Ishaan Jaff
edef33abb2 test_audio_speech_router 2024-11-28 20:17:37 -08:00
Ishaan Jaff
97cd5526ae fix aspeech on router 2024-11-28 20:17:19 -08:00
Ishaan Jaff
7901eee0b7 doc Migrating Databases 2024-11-28 07:33:43 -08:00
40 changed files with 781 additions and 1042 deletions

View file

@ -1408,7 +1408,7 @@ jobs:
command: | command: |
docker run -d \ docker run -d \
-p 4000:4000 \ -p 4000:4000 \
-e DATABASE_URL=$PROXY_DATABASE_URL_2 \ -e DATABASE_URL=$PROXY_DATABASE_URL \
-e LITELLM_MASTER_KEY="sk-1234" \ -e LITELLM_MASTER_KEY="sk-1234" \
-e OPENAI_API_KEY=$OPENAI_API_KEY \ -e OPENAI_API_KEY=$OPENAI_API_KEY \
-e UI_USERNAME="admin" \ -e UI_USERNAME="admin" \

View file

@ -69,3 +69,24 @@ When disabling spend logs (`disable_spend_logs: True`):
When disabling error logs (`disable_error_logs: True`): When disabling error logs (`disable_error_logs: True`):
- You **will not** be able to view Errors on the LiteLLM UI - You **will not** be able to view Errors on the LiteLLM UI
- You **will** continue seeing error logs in your application logs and any other logging integrations you are using - You **will** continue seeing error logs in your application logs and any other logging integrations you are using
## Migrating Databases
If you need to migrate databases, the following tables should be copied to ensure continuation of services and no downtime:
| Table Name | Description |
|------------|-------------|
| LiteLLM_VerificationToken | **Required** to ensure existing virtual keys continue working |
| LiteLLM_UserTable | **Required** to ensure existing virtual keys continue working |
| LiteLLM_TeamTable | **Required** to ensure Teams are migrated |
| LiteLLM_TeamMembership | **Required** to ensure Teams member budgets are migrated |
| LiteLLM_BudgetTable | **Required** to migrate existing budgeting settings |
| LiteLLM_OrganizationTable | **Optional** Only migrate if you use Organizations in DB |
| LiteLLM_OrganizationMembership | **Optional** Only migrate if you use Organizations in DB |
| LiteLLM_ProxyModelTable | **Optional** Only migrate if you store your LLMs in the DB (i.e you set `STORE_MODEL_IN_DB=True`) |
| LiteLLM_SpendLogs | **Optional** Only migrate if you want historical data on LiteLLM UI |
| LiteLLM_ErrorLogs | **Optional** Only migrate if you want historical data on LiteLLM UI |

View file

@ -192,13 +192,3 @@ Here is a screenshot of the metrics you can monitor with the LiteLLM Grafana Das
|----------------------|--------------------------------------| |----------------------|--------------------------------------|
| `litellm_llm_api_failed_requests_metric` | **deprecated** use `litellm_proxy_failed_requests_metric` | | `litellm_llm_api_failed_requests_metric` | **deprecated** use `litellm_proxy_failed_requests_metric` |
| `litellm_requests_metric` | **deprecated** use `litellm_proxy_total_requests_metric` | | `litellm_requests_metric` | **deprecated** use `litellm_proxy_total_requests_metric` |
## FAQ
### What are `_created` vs. `_total` metrics?
- `_created` metrics are created once, when the proxy starts
- `_total` metrics are incremented on each request
You should consume the `_total` metrics for your counting purposes

View file

@ -2,9 +2,7 @@
from typing import Optional, List from typing import Optional, List
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.proxy.proxy_server import PrismaClient, HTTPException from litellm.proxy.proxy_server import PrismaClient, HTTPException
from litellm.llms.custom_httpx.http_handler import HTTPHandler
import collections import collections
import httpx
from datetime import datetime from datetime import datetime
@ -116,6 +114,7 @@ async def ui_get_spend_by_tags(
def _forecast_daily_cost(data: list): def _forecast_daily_cost(data: list):
import requests # type: ignore
from datetime import datetime, timedelta from datetime import datetime, timedelta
if len(data) == 0: if len(data) == 0:
@ -137,17 +136,17 @@ def _forecast_daily_cost(data: list):
print("last entry date", last_entry_date) print("last entry date", last_entry_date)
# Assuming today_date is a datetime object
today_date = datetime.now()
# Calculate the last day of the month # Calculate the last day of the month
last_day_of_todays_month = datetime( last_day_of_todays_month = datetime(
today_date.year, today_date.month % 12 + 1, 1 today_date.year, today_date.month % 12 + 1, 1
) - timedelta(days=1) ) - timedelta(days=1)
print("last day of todays month", last_day_of_todays_month)
# Calculate the remaining days in the month # Calculate the remaining days in the month
remaining_days = (last_day_of_todays_month - last_entry_date).days remaining_days = (last_day_of_todays_month - last_entry_date).days
print("remaining days", remaining_days)
current_spend_this_month = 0 current_spend_this_month = 0
series = {} series = {}
for entry in data: for entry in data:
@ -177,19 +176,13 @@ def _forecast_daily_cost(data: list):
"Content-Type": "application/json", "Content-Type": "application/json",
} }
client = HTTPHandler() response = requests.post(
url="https://trend-api-production.up.railway.app/forecast",
try: json=payload,
response = client.post( headers=headers,
url="https://trend-api-production.up.railway.app/forecast", )
json=payload, # check the status code
headers=headers, response.raise_for_status()
)
except httpx.HTTPStatusError as e:
raise HTTPException(
status_code=500,
detail={"error": f"Error getting forecast: {e.response.text}"},
)
json_response = response.json() json_response = response.json()
forecast_data = json_response["forecast"] forecast_data = json_response["forecast"]
@ -213,3 +206,13 @@ def _forecast_daily_cost(data: list):
f"Predicted Spend for { today_month } 2024, ${total_predicted_spend}" f"Predicted Spend for { today_month } 2024, ${total_predicted_spend}"
) )
return {"response": response_data, "predicted_spend": predicted_spend} return {"response": response_data, "predicted_spend": predicted_spend}
# print(f"Date: {entry['date']}, Spend: {entry['spend']}, Response: {response.text}")
# _forecast_daily_cost(
# [
# {"date": "2022-01-01", "spend": 100},
# ]
# )

View file

@ -28,62 +28,6 @@ headers = {
_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0) _DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
_DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600 # 1 hour, re-use the same httpx client for 1 hour _DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600 # 1 hour, re-use the same httpx client for 1 hour
import re
def mask_sensitive_info(error_message):
# Find the start of the key parameter
if isinstance(error_message, str):
key_index = error_message.find("key=")
else:
return error_message
# If key is found
if key_index != -1:
# Find the end of the key parameter (next & or end of string)
next_param = error_message.find("&", key_index)
if next_param == -1:
# If no more parameters, mask until the end of the string
masked_message = error_message[: key_index + 4] + "[REDACTED_API_KEY]"
else:
# Replace the key with redacted value, keeping other parameters
masked_message = (
error_message[: key_index + 4]
+ "[REDACTED_API_KEY]"
+ error_message[next_param:]
)
return masked_message
return error_message
class MaskedHTTPStatusError(httpx.HTTPStatusError):
def __init__(
self, original_error, message: Optional[str] = None, text: Optional[str] = None
):
# Create a new error with the masked URL
masked_url = mask_sensitive_info(str(original_error.request.url))
# Create a new error that looks like the original, but with a masked URL
super().__init__(
message=original_error.message,
request=httpx.Request(
method=original_error.request.method,
url=masked_url,
headers=original_error.request.headers,
content=original_error.request.content,
),
response=httpx.Response(
status_code=original_error.response.status_code,
content=original_error.response.content,
headers=original_error.response.headers,
),
)
self.message = message
self.text = text
class AsyncHTTPHandler: class AsyncHTTPHandler:
def __init__( def __init__(
@ -211,16 +155,13 @@ class AsyncHTTPHandler:
headers=headers, headers=headers,
) )
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
setattr(e, "status_code", e.response.status_code)
if stream is True: if stream is True:
setattr(e, "message", await e.response.aread()) setattr(e, "message", await e.response.aread())
setattr(e, "text", await e.response.aread()) setattr(e, "text", await e.response.aread())
else: else:
setattr(e, "message", mask_sensitive_info(e.response.text)) setattr(e, "message", e.response.text)
setattr(e, "text", mask_sensitive_info(e.response.text)) setattr(e, "text", e.response.text)
setattr(e, "status_code", e.response.status_code)
raise e raise e
except Exception as e: except Exception as e:
raise e raise e
@ -458,17 +399,11 @@ class HTTPHandler:
llm_provider="litellm-httpx-handler", llm_provider="litellm-httpx-handler",
) )
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
if stream is True:
setattr(e, "message", mask_sensitive_info(e.response.read()))
setattr(e, "text", mask_sensitive_info(e.response.read()))
else:
error_text = mask_sensitive_info(e.response.text)
setattr(e, "message", error_text)
setattr(e, "text", error_text)
setattr(e, "status_code", e.response.status_code) setattr(e, "status_code", e.response.status_code)
if stream is True:
setattr(e, "message", e.response.read())
else:
setattr(e, "message", e.response.text)
raise e raise e
except Exception as e: except Exception as e:
raise e raise e

View file

@ -1159,44 +1159,15 @@ def convert_to_anthropic_tool_result(
] ]
} }
""" """
anthropic_content: Union[ content_str: str = ""
str,
List[Union[AnthropicMessagesToolResultContent, AnthropicMessagesImageParam]],
] = ""
if isinstance(message["content"], str): if isinstance(message["content"], str):
anthropic_content = message["content"] content_str = message["content"]
elif isinstance(message["content"], List): elif isinstance(message["content"], List):
content_list = message["content"] content_list = message["content"]
anthropic_content_list: List[
Union[AnthropicMessagesToolResultContent, AnthropicMessagesImageParam]
] = []
for content in content_list: for content in content_list:
if content["type"] == "text": if content["type"] == "text":
anthropic_content_list.append( content_str += content["text"]
AnthropicMessagesToolResultContent(
type="text",
text=content["text"],
)
)
elif content["type"] == "image_url":
if isinstance(content["image_url"], str):
image_chunk = convert_to_anthropic_image_obj(content["image_url"])
else:
image_chunk = convert_to_anthropic_image_obj(
content["image_url"]["url"]
)
anthropic_content_list.append(
AnthropicMessagesImageParam(
type="image",
source=AnthropicContentParamSource(
type="base64",
media_type=image_chunk["media_type"],
data=image_chunk["data"],
),
)
)
anthropic_content = anthropic_content_list
anthropic_tool_result: Optional[AnthropicMessagesToolResultParam] = None anthropic_tool_result: Optional[AnthropicMessagesToolResultParam] = None
## PROMPT CACHING CHECK ## ## PROMPT CACHING CHECK ##
cache_control = message.get("cache_control", None) cache_control = message.get("cache_control", None)
@ -1207,14 +1178,14 @@ def convert_to_anthropic_tool_result(
# We can't determine from openai message format whether it's a successful or # We can't determine from openai message format whether it's a successful or
# error call result so default to the successful result template # error call result so default to the successful result template
anthropic_tool_result = AnthropicMessagesToolResultParam( anthropic_tool_result = AnthropicMessagesToolResultParam(
type="tool_result", tool_use_id=tool_call_id, content=anthropic_content type="tool_result", tool_use_id=tool_call_id, content=content_str
) )
if message["role"] == "function": if message["role"] == "function":
function_message: ChatCompletionFunctionMessage = message function_message: ChatCompletionFunctionMessage = message
tool_call_id = function_message.get("tool_call_id") or str(uuid.uuid4()) tool_call_id = function_message.get("tool_call_id") or str(uuid.uuid4())
anthropic_tool_result = AnthropicMessagesToolResultParam( anthropic_tool_result = AnthropicMessagesToolResultParam(
type="tool_result", tool_use_id=tool_call_id, content=anthropic_content type="tool_result", tool_use_id=tool_call_id, content=content_str
) )
if anthropic_tool_result is None: if anthropic_tool_result is None:

View file

@ -107,10 +107,6 @@ def _get_image_mime_type_from_url(url: str) -> Optional[str]:
return "image/png" return "image/png"
elif url.endswith(".webp"): elif url.endswith(".webp"):
return "image/webp" return "image/webp"
elif url.endswith(".mp4"):
return "video/mp4"
elif url.endswith(".pdf"):
return "application/pdf"
return None return None

View file

@ -15,22 +15,6 @@ model_list:
litellm_params: litellm_params:
model: openai/gpt-4o-realtime-preview-2024-10-01 model: openai/gpt-4o-realtime-preview-2024-10-01
api_key: os.environ/OPENAI_API_KEY api_key: os.environ/OPENAI_API_KEY
- model_name: openai/*
litellm_params:
model: openai/*
api_key: os.environ/OPENAI_API_KEY
- model_name: openai/*
litellm_params:
model: openai/*
api_key: os.environ/OPENAI_API_KEY
model_info:
access_groups: ["public-openai-models"]
- model_name: openai/gpt-4o
litellm_params:
model: openai/gpt-4o
api_key: os.environ/OPENAI_API_KEY
model_info:
access_groups: ["private-openai-models"]
router_settings: router_settings:
routing_strategy: usage-based-routing-v2 routing_strategy: usage-based-routing-v2

View file

@ -2183,11 +2183,3 @@ PassThroughEndpointLoggingResultValues = Union[
class PassThroughEndpointLoggingTypedDict(TypedDict): class PassThroughEndpointLoggingTypedDict(TypedDict):
result: Optional[PassThroughEndpointLoggingResultValues] result: Optional[PassThroughEndpointLoggingResultValues]
kwargs: dict kwargs: dict
LiteLLM_ManagementEndpoint_MetadataFields = [
"model_rpm_limit",
"model_tpm_limit",
"guardrails",
"tags",
]

View file

@ -60,7 +60,6 @@ def common_checks( # noqa: PLR0915
global_proxy_spend: Optional[float], global_proxy_spend: Optional[float],
general_settings: dict, general_settings: dict,
route: str, route: str,
llm_router: Optional[litellm.Router],
) -> bool: ) -> bool:
""" """
Common checks across jwt + key-based auth. Common checks across jwt + key-based auth.
@ -98,12 +97,7 @@ def common_checks( # noqa: PLR0915
# this means the team has access to all models on the proxy # this means the team has access to all models on the proxy
pass pass
# check if the team model is an access_group # check if the team model is an access_group
elif ( elif model_in_access_group(_model, team_object.models) is True:
model_in_access_group(
model=_model, team_models=team_object.models, llm_router=llm_router
)
is True
):
pass pass
elif _model and "*" in _model: elif _model and "*" in _model:
pass pass
@ -379,33 +373,36 @@ async def get_end_user_object(
return None return None
def model_in_access_group( def model_in_access_group(model: str, team_models: Optional[List[str]]) -> bool:
model: str, team_models: Optional[List[str]], llm_router: Optional[litellm.Router]
) -> bool:
from collections import defaultdict from collections import defaultdict
from litellm.proxy.proxy_server import llm_router
if team_models is None: if team_models is None:
return True return True
if model in team_models: if model in team_models:
return True return True
access_groups: dict[str, list[str]] = defaultdict(list) access_groups = defaultdict(list)
if llm_router: if llm_router:
access_groups = llm_router.get_model_access_groups(model_name=model) access_groups = llm_router.get_model_access_groups()
models_in_current_access_groups = []
if len(access_groups) > 0: # check if token contains any model access groups if len(access_groups) > 0: # check if token contains any model access groups
for idx, m in enumerate( for idx, m in enumerate(
team_models team_models
): # loop token models, if any of them are an access group add the access group ): # loop token models, if any of them are an access group add the access group
if m in access_groups: if m in access_groups:
return True # if it is an access group we need to remove it from valid_token.models
models_in_group = access_groups[m]
models_in_current_access_groups.extend(models_in_group)
# Filter out models that are access_groups # Filter out models that are access_groups
filtered_models = [m for m in team_models if m not in access_groups] filtered_models = [m for m in team_models if m not in access_groups]
filtered_models += models_in_current_access_groups
if model in filtered_models: if model in filtered_models:
return True return True
return False return False
@ -526,6 +523,10 @@ async def _cache_management_object(
proxy_logging_obj: Optional[ProxyLogging], proxy_logging_obj: Optional[ProxyLogging],
): ):
await user_api_key_cache.async_set_cache(key=key, value=value) await user_api_key_cache.async_set_cache(key=key, value=value)
if proxy_logging_obj is not None:
await proxy_logging_obj.internal_usage_cache.dual_cache.async_set_cache(
key=key, value=value
)
async def _cache_team_object( async def _cache_team_object(
@ -877,10 +878,7 @@ async def get_org_object(
async def can_key_call_model( async def can_key_call_model(
model: str, model: str, llm_model_list: Optional[list], valid_token: UserAPIKeyAuth
llm_model_list: Optional[list],
valid_token: UserAPIKeyAuth,
llm_router: Optional[litellm.Router],
) -> Literal[True]: ) -> Literal[True]:
""" """
Checks if token can call a given model Checks if token can call a given model
@ -900,29 +898,35 @@ async def can_key_call_model(
) )
from collections import defaultdict from collections import defaultdict
from litellm.proxy.proxy_server import llm_router
access_groups = defaultdict(list) access_groups = defaultdict(list)
if llm_router: if llm_router:
access_groups = llm_router.get_model_access_groups(model_name=model) access_groups = llm_router.get_model_access_groups()
if ( models_in_current_access_groups = []
len(access_groups) > 0 and llm_router is not None if len(access_groups) > 0: # check if token contains any model access groups
): # check if token contains any model access groups
for idx, m in enumerate( for idx, m in enumerate(
valid_token.models valid_token.models
): # loop token models, if any of them are an access group add the access group ): # loop token models, if any of them are an access group add the access group
if m in access_groups: if m in access_groups:
return True # if it is an access group we need to remove it from valid_token.models
models_in_group = access_groups[m]
models_in_current_access_groups.extend(models_in_group)
# Filter out models that are access_groups # Filter out models that are access_groups
filtered_models = [m for m in valid_token.models if m not in access_groups] filtered_models = [m for m in valid_token.models if m not in access_groups]
filtered_models += models_in_current_access_groups
verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}") verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
all_model_access: bool = False all_model_access: bool = False
if ( if (
len(filtered_models) == 0 and len(valid_token.models) == 0 len(filtered_models) == 0
) or "*" in filtered_models: or "*" in filtered_models
or "openai/*" in filtered_models
):
all_model_access = True all_model_access = True
if model is not None and model not in filtered_models and all_model_access is False: if model is not None and model not in filtered_models and all_model_access is False:

View file

@ -259,7 +259,6 @@ async def user_api_key_auth( # noqa: PLR0915
jwt_handler, jwt_handler,
litellm_proxy_admin_name, litellm_proxy_admin_name,
llm_model_list, llm_model_list,
llm_router,
master_key, master_key,
open_telemetry_logger, open_telemetry_logger,
prisma_client, prisma_client,
@ -543,7 +542,6 @@ async def user_api_key_auth( # noqa: PLR0915
general_settings=general_settings, general_settings=general_settings,
global_proxy_spend=global_proxy_spend, global_proxy_spend=global_proxy_spend,
route=route, route=route,
llm_router=llm_router,
) )
# return UserAPIKeyAuth object # return UserAPIKeyAuth object
@ -907,7 +905,6 @@ async def user_api_key_auth( # noqa: PLR0915
model=model, model=model,
llm_model_list=llm_model_list, llm_model_list=llm_model_list,
valid_token=valid_token, valid_token=valid_token,
llm_router=llm_router,
) )
if fallback_models is not None: if fallback_models is not None:
@ -916,7 +913,6 @@ async def user_api_key_auth( # noqa: PLR0915
model=m, model=m,
llm_model_list=llm_model_list, llm_model_list=llm_model_list,
valid_token=valid_token, valid_token=valid_token,
llm_router=llm_router,
) )
# Check 2. If user_id for this token is in budget - done in common_checks() # Check 2. If user_id for this token is in budget - done in common_checks()
@ -1177,7 +1173,6 @@ async def user_api_key_auth( # noqa: PLR0915
general_settings=general_settings, general_settings=general_settings,
global_proxy_spend=global_proxy_spend, global_proxy_spend=global_proxy_spend,
route=route, route=route,
llm_router=llm_router,
) )
# Token passed all checks # Token passed all checks
if valid_token is None: if valid_token is None:

View file

@ -214,10 +214,10 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
prepared_request.url, prepared_request.url,
prepared_request.headers, prepared_request.headers,
) )
_json_data = json.dumps(request_data) # type: ignore
response = await self.async_handler.post( response = await self.async_handler.post(
url=prepared_request.url, url=prepared_request.url,
data=prepared_request.body, # type: ignore json=request_data, # type: ignore
headers=prepared_request.headers, # type: ignore headers=prepared_request.headers, # type: ignore
) )
verbose_proxy_logger.debug("Bedrock AI response: %s", response.text) verbose_proxy_logger.debug("Bedrock AI response: %s", response.text)

View file

@ -32,7 +32,6 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.key_management_endpoints import ( from litellm.proxy.management_endpoints.key_management_endpoints import (
duration_in_seconds, duration_in_seconds,
generate_key_helper_fn, generate_key_helper_fn,
prepare_metadata_fields,
) )
from litellm.proxy.management_helpers.utils import ( from litellm.proxy.management_helpers.utils import (
add_new_member, add_new_member,
@ -43,7 +42,7 @@ from litellm.proxy.utils import handle_exception_on_proxy
router = APIRouter() router = APIRouter()
def _update_internal_new_user_params(data_json: dict, data: NewUserRequest) -> dict: def _update_internal_user_params(data_json: dict, data: NewUserRequest) -> dict:
if "user_id" in data_json and data_json["user_id"] is None: if "user_id" in data_json and data_json["user_id"] is None:
data_json["user_id"] = str(uuid.uuid4()) data_json["user_id"] = str(uuid.uuid4())
auto_create_key = data_json.pop("auto_create_key", True) auto_create_key = data_json.pop("auto_create_key", True)
@ -146,7 +145,7 @@ async def new_user(
from litellm.proxy.proxy_server import general_settings, proxy_logging_obj from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
data_json = data.json() # type: ignore data_json = data.json() # type: ignore
data_json = _update_internal_new_user_params(data_json, data) data_json = _update_internal_user_params(data_json, data)
response = await generate_key_helper_fn(request_type="user", **data_json) response = await generate_key_helper_fn(request_type="user", **data_json)
# Admin UI Logic # Admin UI Logic
@ -439,52 +438,6 @@ async def user_info( # noqa: PLR0915
raise handle_exception_on_proxy(e) raise handle_exception_on_proxy(e)
def _update_internal_user_params(data_json: dict, data: UpdateUserRequest) -> dict:
non_default_values = {}
for k, v in data_json.items():
if (
v is not None
and v
not in (
[],
{},
0,
)
and k not in LiteLLM_ManagementEndpoint_MetadataFields
): # models default to [], spend defaults to 0, we should not reset these values
non_default_values[k] = v
is_internal_user = False
if data.user_role == LitellmUserRoles.INTERNAL_USER:
is_internal_user = True
if "budget_duration" in non_default_values:
duration_s = duration_in_seconds(duration=non_default_values["budget_duration"])
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
non_default_values["budget_reset_at"] = user_reset_at
if "max_budget" not in non_default_values:
if (
is_internal_user and litellm.max_internal_user_budget is not None
): # applies internal user limits, if user role updated
non_default_values["max_budget"] = litellm.max_internal_user_budget
if (
"budget_duration" not in non_default_values
): # applies internal user limits, if user role updated
if is_internal_user and litellm.internal_user_budget_duration is not None:
non_default_values["budget_duration"] = (
litellm.internal_user_budget_duration
)
duration_s = duration_in_seconds(
duration=non_default_values["budget_duration"]
)
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
non_default_values["budget_reset_at"] = user_reset_at
return non_default_values
@router.post( @router.post(
"/user/update", "/user/update",
tags=["Internal User management"], tags=["Internal User management"],
@ -506,7 +459,6 @@ async def user_update(
"user_id": "test-litellm-user-4", "user_id": "test-litellm-user-4",
"user_role": "proxy_admin_viewer" "user_role": "proxy_admin_viewer"
}' }'
```
Parameters: Parameters:
- user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated. - user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated.
@ -539,7 +491,7 @@ async def user_update(
- duration: Optional[str] - [NOT IMPLEMENTED]. - duration: Optional[str] - [NOT IMPLEMENTED].
- key_alias: Optional[str] - [NOT IMPLEMENTED]. - key_alias: Optional[str] - [NOT IMPLEMENTED].
```
""" """
from litellm.proxy.proxy_server import prisma_client from litellm.proxy.proxy_server import prisma_client
@ -550,21 +502,46 @@ async def user_update(
raise Exception("Not connected to DB!") raise Exception("Not connected to DB!")
# get non default values for key # get non default values for key
non_default_values = _update_internal_user_params( non_default_values = {}
data_json=data_json, data=data for k, v in data_json.items():
) if v is not None and v not in (
[],
{},
0,
): # models default to [], spend defaults to 0, we should not reset these values
non_default_values[k] = v
existing_user_row = await prisma_client.get_data( is_internal_user = False
user_id=data.user_id, table_name="user", query_type="find_unique" if data.user_role == LitellmUserRoles.INTERNAL_USER:
) is_internal_user = True
existing_metadata = existing_user_row.metadata if existing_user_row else {} if "budget_duration" in non_default_values:
duration_s = duration_in_seconds(
duration=non_default_values["budget_duration"]
)
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
non_default_values["budget_reset_at"] = user_reset_at
non_default_values = prepare_metadata_fields( if "max_budget" not in non_default_values:
data=data, if (
non_default_values=non_default_values, is_internal_user and litellm.max_internal_user_budget is not None
existing_metadata=existing_metadata or {}, ): # applies internal user limits, if user role updated
) non_default_values["max_budget"] = litellm.max_internal_user_budget
if (
"budget_duration" not in non_default_values
): # applies internal user limits, if user role updated
if is_internal_user and litellm.internal_user_budget_duration is not None:
non_default_values["budget_duration"] = (
litellm.internal_user_budget_duration
)
duration_s = duration_in_seconds(
duration=non_default_values["budget_duration"]
)
user_reset_at = datetime.now(timezone.utc) + timedelta(
seconds=duration_s
)
non_default_values["budget_reset_at"] = user_reset_at
## ADD USER, IF NEW ## ## ADD USER, IF NEW ##
verbose_proxy_logger.debug("/user/update: Received data = %s", data) verbose_proxy_logger.debug("/user/update: Received data = %s", data)

View file

@ -17,7 +17,7 @@ import secrets
import traceback import traceback
import uuid import uuid
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import List, Optional, Tuple, cast from typing import List, Optional, Tuple
import fastapi import fastapi
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
@ -394,8 +394,7 @@ async def generate_key_fn( # noqa: PLR0915
} }
) )
_budget_id = getattr(_budget, "budget_id", None) _budget_id = getattr(_budget, "budget_id", None)
data_json = data.model_dump(exclude_unset=True, exclude_none=True) # type: ignore data_json = data.json() # type: ignore
# if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users # if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users
if "max_budget" in data_json: if "max_budget" in data_json:
data_json["key_max_budget"] = data_json.pop("max_budget", None) data_json["key_max_budget"] = data_json.pop("max_budget", None)
@ -453,52 +452,12 @@ async def generate_key_fn( # noqa: PLR0915
raise handle_exception_on_proxy(e) raise handle_exception_on_proxy(e)
def prepare_metadata_fields(
data: BaseModel, non_default_values: dict, existing_metadata: dict
) -> dict:
"""
Check LiteLLM_ManagementEndpoint_MetadataFields (proxy/_types.py) for fields that are allowed to be updated
"""
if "metadata" not in non_default_values: # allow user to set metadata to none
non_default_values["metadata"] = existing_metadata.copy()
casted_metadata = cast(dict, non_default_values["metadata"])
data_json = data.model_dump(exclude_unset=True, exclude_none=True)
try:
for k, v in data_json.items():
if k == "model_tpm_limit" or k == "model_rpm_limit":
if k not in casted_metadata or casted_metadata[k] is None:
casted_metadata[k] = {}
casted_metadata[k].update(v)
if k == "tags" or k == "guardrails":
if k not in casted_metadata or casted_metadata[k] is None:
casted_metadata[k] = []
seen = set(casted_metadata[k])
casted_metadata[k].extend(
x for x in v if x not in seen and not seen.add(x) # type: ignore
) # prevent duplicates from being added + maintain initial order
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.prepare_metadata_fields(): Exception occured - {}".format(
str(e)
)
)
non_default_values["metadata"] = casted_metadata
return non_default_values
def prepare_key_update_data( def prepare_key_update_data(
data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row
): ):
data_json: dict = data.model_dump(exclude_unset=True) data_json: dict = data.model_dump(exclude_unset=True)
data_json.pop("key", None) data_json.pop("key", None)
_metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails", "tags"] _metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"]
non_default_values = {} non_default_values = {}
for k, v in data_json.items(): for k, v in data_json.items():
if k in _metadata_fields: if k in _metadata_fields:
@ -526,9 +485,21 @@ def prepare_key_update_data(
_metadata = existing_key_row.metadata or {} _metadata = existing_key_row.metadata or {}
non_default_values = prepare_metadata_fields( if data.model_tpm_limit:
data=data, non_default_values=non_default_values, existing_metadata=_metadata if "model_tpm_limit" not in _metadata:
) _metadata["model_tpm_limit"] = {}
_metadata["model_tpm_limit"].update(data.model_tpm_limit)
non_default_values["metadata"] = _metadata
if data.model_rpm_limit:
if "model_rpm_limit" not in _metadata:
_metadata["model_rpm_limit"] = {}
_metadata["model_rpm_limit"].update(data.model_rpm_limit)
non_default_values["metadata"] = _metadata
if data.guardrails:
_metadata["guardrails"] = data.guardrails
non_default_values["metadata"] = _metadata
return non_default_values return non_default_values
@ -953,11 +924,11 @@ async def generate_key_helper_fn( # noqa: PLR0915
request_type: Literal[ request_type: Literal[
"user", "key" "user", "key"
], # identifies if this request is from /user/new or /key/generate ], # identifies if this request is from /user/new or /key/generate
duration: Optional[str] = None, duration: Optional[str],
models: list = [], models: list,
aliases: dict = {}, aliases: dict,
config: dict = {}, config: dict,
spend: float = 0.0, spend: float,
key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key
key_budget_duration: Optional[str] = None, key_budget_duration: Optional[str] = None,
budget_id: Optional[float] = None, # budget id <-> LiteLLM_BudgetTable budget_id: Optional[float] = None, # budget id <-> LiteLLM_BudgetTable
@ -986,8 +957,8 @@ async def generate_key_helper_fn( # noqa: PLR0915
allowed_cache_controls: Optional[list] = [], allowed_cache_controls: Optional[list] = [],
permissions: Optional[dict] = {}, permissions: Optional[dict] = {},
model_max_budget: Optional[dict] = {}, model_max_budget: Optional[dict] = {},
model_rpm_limit: Optional[dict] = None, model_rpm_limit: Optional[dict] = {},
model_tpm_limit: Optional[dict] = None, model_tpm_limit: Optional[dict] = {},
guardrails: Optional[list] = None, guardrails: Optional[list] = None,
teams: Optional[list] = None, teams: Optional[list] = None,
organization_id: Optional[str] = None, organization_id: Optional[str] = None,

View file

@ -1689,11 +1689,15 @@ class Router:
and potential_model_client is not None and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key and dynamic_api_key != potential_model_client.api_key
): ):
pass model_client = None
else: else:
pass model_client = potential_model_client
response = await litellm.aspeech(**data, **kwargs) response = await litellm.aspeech(
**data,
client=model_client,
**kwargs,
)
return response return response
except Exception as e: except Exception as e:
@ -4712,9 +4716,6 @@ class Router:
if hasattr(self, "model_list"): if hasattr(self, "model_list"):
returned_models: List[DeploymentTypedDict] = [] returned_models: List[DeploymentTypedDict] = []
if model_name is not None:
returned_models.extend(self._get_all_deployments(model_name=model_name))
if hasattr(self, "model_group_alias"): if hasattr(self, "model_group_alias"):
for model_alias, model_value in self.model_group_alias.items(): for model_alias, model_value in self.model_group_alias.items():
@ -4746,21 +4747,17 @@ class Router:
returned_models += self.model_list returned_models += self.model_list
return returned_models return returned_models
returned_models.extend(self._get_all_deployments(model_name=model_name))
return returned_models return returned_models
return None return None
def get_model_access_groups(self, model_name: Optional[str] = None): def get_model_access_groups(self):
"""
If model_name is provided, only return access groups for that model.
"""
from collections import defaultdict from collections import defaultdict
access_groups = defaultdict(list) access_groups = defaultdict(list)
model_list = self.get_model_list(model_name=model_name) if self.model_list:
if model_list: for m in self.model_list:
for m in model_list:
for group in m.get("model_info", {}).get("access_groups", []): for group in m.get("model_info", {}).get("access_groups", []):
model_name = m["model_name"] model_name = m["model_name"]
access_groups[group].append(model_name) access_groups[group].append(model_name)

View file

@ -79,9 +79,7 @@ class PatternMatchRouter:
return new_deployments return new_deployments
def route( def route(self, request: Optional[str]) -> Optional[List[Dict]]:
self, request: Optional[str], filtered_model_names: Optional[List[str]] = None
) -> Optional[List[Dict]]:
""" """
Route a requested model to the corresponding llm deployments based on the regex pattern Route a requested model to the corresponding llm deployments based on the regex pattern
@ -91,26 +89,14 @@ class PatternMatchRouter:
Args: Args:
request: Optional[str] request: Optional[str]
filtered_model_names: Optional[List[str]] - if provided, only return deployments that match the filtered_model_names
Returns: Returns:
Optional[List[Deployment]]: llm deployments Optional[List[Deployment]]: llm deployments
""" """
try: try:
if request is None: if request is None:
return None return None
regex_filtered_model_names = (
[self._pattern_to_regex(m) for m in filtered_model_names]
if filtered_model_names is not None
else []
)
for pattern, llm_deployments in self.patterns.items(): for pattern, llm_deployments in self.patterns.items():
if (
filtered_model_names is not None
and pattern not in regex_filtered_model_names
):
continue
pattern_match = re.match(pattern, request) pattern_match = re.match(pattern, request)
if pattern_match: if pattern_match:
return self._return_pattern_matched_deployments( return self._return_pattern_matched_deployments(

View file

@ -355,7 +355,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
class DeploymentTypedDict(TypedDict, total=False): class DeploymentTypedDict(TypedDict, total=False):
model_name: Required[str] model_name: Required[str]
litellm_params: Required[LiteLLMParamsTypedDict] litellm_params: Required[LiteLLMParamsTypedDict]
model_info: dict model_info: Optional[dict]
SPECIAL_MODEL_INFO_PARAMS = [ SPECIAL_MODEL_INFO_PARAMS = [

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "litellm" name = "litellm"
version = "1.53.2" version = "1.53.1"
description = "Library to easily interface with LLM API providers" description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"] authors = ["BerriAI"]
license = "MIT" license = "MIT"
@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.commitizen] [tool.commitizen]
version = "1.53.2" version = "1.53.1"
version_files = [ version_files = [
"pyproject.toml:^version" "pyproject.toml:^version"
] ]

View file

@ -1,6 +1,6 @@
# LITELLM PROXY DEPENDENCIES # # LITELLM PROXY DEPENDENCIES #
anyio==4.4.0 # openai + http req. anyio==4.4.0 # openai + http req.
openai==1.55.3 # openai req. openai==1.54.0 # openai req.
fastapi==0.111.0 # server dep fastapi==0.111.0 # server dep
backoff==2.2.1 # server dep backoff==2.2.1 # server dep
pyyaml==6.0.0 # server dep pyyaml==6.0.0 # server dep

View file

@ -1,3 +1 @@
Unit tests for individual LLM providers. More tests under `litellm/litellm/tests/*`.
Name of the test file is the name of the LLM provider - e.g. `test_openai.py` is for OpenAI.

File diff suppressed because one or more lines are too long

View file

@ -45,59 +45,81 @@ def test_map_azure_model_group(model_group_header, expected_model):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_azure_ai_with_image_url(): @pytest.mark.respx
async def test_azure_ai_with_image_url(respx_mock: MockRouter):
""" """
Important test: Important test:
Test that Azure AI studio can handle image_url passed when content is a list containing both text and image_url Test that Azure AI studio can handle image_url passed when content is a list containing both text and image_url
""" """
from openai import AsyncOpenAI
litellm.set_verbose = True litellm.set_verbose = True
client = AsyncOpenAI( # Mock response based on the actual API response
api_key="fake-api-key", mock_response = {
base_url="https://Phi-3-5-vision-instruct-dcvov.eastus2.models.ai.azure.com", "id": "cmpl-53860ea1efa24d2883555bfec13d2254",
) "choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": None,
"message": {
"content": "The image displays a graphic with the text 'LiteLLM' in black",
"role": "assistant",
"refusal": None,
"audio": None,
"function_call": None,
"tool_calls": None,
},
}
],
"created": 1731801937,
"model": "phi35-vision-instruct",
"object": "chat.completion",
"usage": {
"completion_tokens": 69,
"prompt_tokens": 617,
"total_tokens": 686,
"completion_tokens_details": None,
"prompt_tokens_details": None,
},
}
with patch.object( # Mock the API request
client.chat.completions.with_raw_response, "create" mock_request = respx_mock.post(
) as mock_client: "https://Phi-3-5-vision-instruct-dcvov.eastus2.models.ai.azure.com"
try: ).mock(return_value=httpx.Response(200, json=mock_response))
await litellm.acompletion(
model="azure_ai/Phi-3-5-vision-instruct-dcvov", response = await litellm.acompletion(
api_base="https://Phi-3-5-vision-instruct-dcvov.eastus2.models.ai.azure.com", model="azure_ai/Phi-3-5-vision-instruct-dcvov",
messages=[ api_base="https://Phi-3-5-vision-instruct-dcvov.eastus2.models.ai.azure.com",
messages=[
{
"role": "user",
"content": [
{ {
"role": "user", "type": "text",
"content": [ "text": "What is in this image?",
{ },
"type": "text", {
"text": "What is in this image?", "type": "image_url",
}, "image_url": {
{ "url": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
"type": "image_url", },
"image_url": {
"url": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
},
},
],
}, },
], ],
api_key="fake-api-key", },
client=client, ],
) api_key="fake-api-key",
except Exception as e: )
traceback.print_exc()
print(f"Error: {e}")
# Verify the request was made # Verify the request was made
mock_client.assert_called_once() assert mock_request.called
# Check the request body # Check the request body
request_body = mock_client.call_args.kwargs request_body = json.loads(mock_request.calls[0].request.content)
assert request_body["model"] == "Phi-3-5-vision-instruct-dcvov" assert request_body == {
assert request_body["messages"] == [ "model": "Phi-3-5-vision-instruct-dcvov",
"messages": [
{ {
"role": "user", "role": "user",
"content": [ "content": [
@ -110,4 +132,7 @@ async def test_azure_ai_with_image_url():
}, },
], ],
} }
] ],
}
print(f"response: {response}")

View file

@ -13,7 +13,6 @@ load_dotenv()
import httpx import httpx
import pytest import pytest
from respx import MockRouter from respx import MockRouter
from unittest.mock import patch, MagicMock, AsyncMock
import litellm import litellm
from litellm import Choices, Message, ModelResponse from litellm import Choices, Message, ModelResponse
@ -42,58 +41,56 @@ def return_mocked_response(model: str):
"bedrock/mistral.mistral-large-2407-v1:0", "bedrock/mistral.mistral-large-2407-v1:0",
], ],
) )
@pytest.mark.respx
@pytest.mark.asyncio() @pytest.mark.asyncio()
async def test_bedrock_max_completion_tokens(model: str): async def test_bedrock_max_completion_tokens(model: str, respx_mock: MockRouter):
""" """
Tests that: Tests that:
- max_completion_tokens is passed as max_tokens to bedrock models - max_completion_tokens is passed as max_tokens to bedrock models
""" """
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
litellm.set_verbose = True litellm.set_verbose = True
client = AsyncHTTPHandler()
mock_response = return_mocked_response(model) mock_response = return_mocked_response(model)
_model = model.split("/")[1] _model = model.split("/")[1]
print("\n\nmock_response: ", mock_response) print("\n\nmock_response: ", mock_response)
url = f"https://bedrock-runtime.us-west-2.amazonaws.com/model/{_model}/converse"
mock_request = respx_mock.post(url).mock(
return_value=httpx.Response(200, json=mock_response)
)
with patch.object(client, "post") as mock_client: response = await litellm.acompletion(
try: model=model,
response = await litellm.acompletion( max_completion_tokens=10,
model=model, messages=[{"role": "user", "content": "Hello!"}],
max_completion_tokens=10, )
messages=[{"role": "user", "content": "Hello!"}],
client=client,
)
except Exception as e:
print(f"Error: {e}")
mock_client.assert_called_once() assert mock_request.called
request_body = json.loads(mock_client.call_args.kwargs["data"]) request_body = json.loads(mock_request.calls[0].request.content)
print("request_body: ", request_body) print("request_body: ", request_body)
assert request_body == { assert request_body == {
"messages": [{"role": "user", "content": [{"text": "Hello!"}]}], "messages": [{"role": "user", "content": [{"text": "Hello!"}]}],
"additionalModelRequestFields": {}, "additionalModelRequestFields": {},
"system": [], "system": [],
"inferenceConfig": {"maxTokens": 10}, "inferenceConfig": {"maxTokens": 10},
} }
print(f"response: {response}")
assert isinstance(response, ModelResponse)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model", "model",
["anthropic/claude-3-sonnet-20240229", "anthropic/claude-3-opus-20240229"], ["anthropic/claude-3-sonnet-20240229", "anthropic/claude-3-opus-20240229,"],
) )
@pytest.mark.respx
@pytest.mark.asyncio() @pytest.mark.asyncio()
async def test_anthropic_api_max_completion_tokens(model: str): async def test_anthropic_api_max_completion_tokens(model: str, respx_mock: MockRouter):
""" """
Tests that: Tests that:
- max_completion_tokens is passed as max_tokens to anthropic models - max_completion_tokens is passed as max_tokens to anthropic models
""" """
litellm.set_verbose = True litellm.set_verbose = True
from litellm.llms.custom_httpx.http_handler import HTTPHandler
mock_response = { mock_response = {
"content": [{"text": "Hi! My name is Claude.", "type": "text"}], "content": [{"text": "Hi! My name is Claude.", "type": "text"}],
@ -106,32 +103,30 @@ async def test_anthropic_api_max_completion_tokens(model: str):
"usage": {"input_tokens": 2095, "output_tokens": 503}, "usage": {"input_tokens": 2095, "output_tokens": 503},
} }
client = HTTPHandler()
print("\n\nmock_response: ", mock_response) print("\n\nmock_response: ", mock_response)
url = f"https://api.anthropic.com/v1/messages"
mock_request = respx_mock.post(url).mock(
return_value=httpx.Response(200, json=mock_response)
)
with patch.object(client, "post") as mock_client: response = await litellm.acompletion(
try: model=model,
response = await litellm.acompletion( max_completion_tokens=10,
model=model, messages=[{"role": "user", "content": "Hello!"}],
max_completion_tokens=10, )
messages=[{"role": "user", "content": "Hello!"}],
client=client,
)
except Exception as e:
print(f"Error: {e}")
mock_client.assert_called_once()
request_body = mock_client.call_args.kwargs["json"]
print("request_body: ", request_body) assert mock_request.called
request_body = json.loads(mock_request.calls[0].request.content)
assert request_body == { print("request_body: ", request_body)
"messages": [
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]} assert request_body == {
], "messages": [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}],
"max_tokens": 10, "max_tokens": 10,
"model": model.split("/")[-1], "model": model.split("/")[-1],
} }
print(f"response: {response}")
assert isinstance(response, ModelResponse)
def test_all_model_configs(): def test_all_model_configs():

View file

@ -12,78 +12,95 @@ sys.path.insert(
import httpx import httpx
import pytest import pytest
from respx import MockRouter from respx import MockRouter
from unittest.mock import patch, MagicMock, AsyncMock
import litellm import litellm
from litellm import Choices, Message, ModelResponse, EmbeddingResponse, Usage from litellm import Choices, Message, ModelResponse, EmbeddingResponse, Usage
from litellm import completion from litellm import completion
def test_completion_nvidia_nim(): @pytest.mark.respx
from openai import OpenAI def test_completion_nvidia_nim(respx_mock: MockRouter):
litellm.set_verbose = True litellm.set_verbose = True
mock_response = ModelResponse(
id="cmpl-mock",
choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
created=int(datetime.now().timestamp()),
model="databricks/dbrx-instruct",
)
model_name = "nvidia_nim/databricks/dbrx-instruct" model_name = "nvidia_nim/databricks/dbrx-instruct"
client = OpenAI(
api_key="fake-api-key",
)
with patch.object( mock_request = respx_mock.post(
client.chat.completions.with_raw_response, "create" "https://integrate.api.nvidia.com/v1/chat/completions"
) as mock_client: ).mock(return_value=httpx.Response(200, json=mock_response.dict()))
try: try:
completion( response = completion(
model=model_name, model=model_name,
messages=[ messages=[
{ {
"role": "user", "role": "user",
"content": "What's the weather like in Boston today in Fahrenheit?", "content": "What's the weather like in Boston today in Fahrenheit?",
} }
], ],
presence_penalty=0.5, presence_penalty=0.5,
frequency_penalty=0.1, frequency_penalty=0.1,
client=client, )
)
except Exception as e:
print(e)
# Add any assertions here to check the response # Add any assertions here to check the response
print(response)
assert response.choices[0].message.content is not None
assert len(response.choices[0].message.content) > 0
mock_client.assert_called_once() assert mock_request.called
request_body = mock_client.call_args.kwargs request_body = json.loads(mock_request.calls[0].request.content)
print("request_body: ", request_body) print("request_body: ", request_body)
assert request_body["messages"] == [ assert request_body == {
{ "messages": [
"role": "user", {
"content": "What's the weather like in Boston today in Fahrenheit?", "role": "user",
}, "content": "What's the weather like in Boston today in Fahrenheit?",
] }
assert request_body["model"] == "databricks/dbrx-instruct" ],
assert request_body["frequency_penalty"] == 0.1 "model": "databricks/dbrx-instruct",
assert request_body["presence_penalty"] == 0.5 "frequency_penalty": 0.1,
"presence_penalty": 0.5,
}
except litellm.exceptions.Timeout as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_embedding_nvidia_nim(): def test_embedding_nvidia_nim(respx_mock: MockRouter):
litellm.set_verbose = True litellm.set_verbose = True
from openai import OpenAI mock_response = EmbeddingResponse(
model="nvidia_nim/databricks/dbrx-instruct",
client = OpenAI( data=[
api_key="fake-api-key", {
"embedding": [0.1, 0.2, 0.3],
"index": 0,
}
],
usage=Usage(
prompt_tokens=10,
completion_tokens=0,
total_tokens=10,
),
) )
with patch.object(client.embeddings.with_raw_response, "create") as mock_client: mock_request = respx_mock.post(
try: "https://integrate.api.nvidia.com/v1/embeddings"
litellm.embedding( ).mock(return_value=httpx.Response(200, json=mock_response.dict()))
model="nvidia_nim/nvidia/nv-embedqa-e5-v5", response = litellm.embedding(
input="What is the meaning of life?", model="nvidia_nim/nvidia/nv-embedqa-e5-v5",
input_type="passage", input="What is the meaning of life?",
client=client, input_type="passage",
) )
except Exception as e: assert mock_request.called
print(e) request_body = json.loads(mock_request.calls[0].request.content)
mock_client.assert_called_once() print("request_body: ", request_body)
request_body = mock_client.call_args.kwargs assert request_body == {
print("request_body: ", request_body) "input": "What is the meaning of life?",
assert request_body["input"] == "What is the meaning of life?" "model": "nvidia/nv-embedqa-e5-v5",
assert request_body["model"] == "nvidia/nv-embedqa-e5-v5" "input_type": "passage",
assert request_body["extra_body"]["input_type"] == "passage" "encoding_format": "base64",
}

View file

@ -2,7 +2,7 @@ import json
import os import os
import sys import sys
from datetime import datetime from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock from unittest.mock import AsyncMock
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
@ -18,75 +18,87 @@ from litellm import Choices, Message, ModelResponse
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_o1_handle_system_role(): @pytest.mark.respx
async def test_o1_handle_system_role(respx_mock: MockRouter):
""" """
Tests that: Tests that:
- max_tokens is translated to 'max_completion_tokens' - max_tokens is translated to 'max_completion_tokens'
- role 'system' is translated to 'user' - role 'system' is translated to 'user'
""" """
from openai import AsyncOpenAI
litellm.set_verbose = True litellm.set_verbose = True
client = AsyncOpenAI(api_key="fake-api-key") mock_response = ModelResponse(
id="cmpl-mock",
choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
created=int(datetime.now().timestamp()),
model="o1-preview",
)
with patch.object( mock_request = respx_mock.post("https://api.openai.com/v1/chat/completions").mock(
client.chat.completions.with_raw_response, "create" return_value=httpx.Response(200, json=mock_response.dict())
) as mock_client: )
try:
await litellm.acompletion(
model="o1-preview",
max_tokens=10,
messages=[{"role": "system", "content": "Hello!"}],
client=client,
)
except Exception as e:
print(f"Error: {e}")
mock_client.assert_called_once() response = await litellm.acompletion(
request_body = mock_client.call_args.kwargs model="o1-preview",
max_tokens=10,
messages=[{"role": "system", "content": "Hello!"}],
)
print("request_body: ", request_body) assert mock_request.called
request_body = json.loads(mock_request.calls[0].request.content)
assert request_body["model"] == "o1-preview" print("request_body: ", request_body)
assert request_body["max_completion_tokens"] == 10
assert request_body["messages"] == [{"role": "user", "content": "Hello!"}] assert request_body == {
"model": "o1-preview",
"max_completion_tokens": 10,
"messages": [{"role": "user", "content": "Hello!"}],
}
print(f"response: {response}")
assert isinstance(response, ModelResponse)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.respx
@pytest.mark.parametrize("model", ["gpt-4", "gpt-4-0314", "gpt-4-32k", "o1-preview"]) @pytest.mark.parametrize("model", ["gpt-4", "gpt-4-0314", "gpt-4-32k", "o1-preview"])
async def test_o1_max_completion_tokens(model: str): async def test_o1_max_completion_tokens(respx_mock: MockRouter, model: str):
""" """
Tests that: Tests that:
- max_completion_tokens is passed directly to OpenAI chat completion models - max_completion_tokens is passed directly to OpenAI chat completion models
""" """
from openai import AsyncOpenAI
litellm.set_verbose = True litellm.set_verbose = True
client = AsyncOpenAI(api_key="fake-api-key") mock_response = ModelResponse(
id="cmpl-mock",
choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
created=int(datetime.now().timestamp()),
model=model,
)
with patch.object( mock_request = respx_mock.post("https://api.openai.com/v1/chat/completions").mock(
client.chat.completions.with_raw_response, "create" return_value=httpx.Response(200, json=mock_response.dict())
) as mock_client: )
try:
await litellm.acompletion(
model=model,
max_completion_tokens=10,
messages=[{"role": "user", "content": "Hello!"}],
client=client,
)
except Exception as e:
print(f"Error: {e}")
mock_client.assert_called_once() response = await litellm.acompletion(
request_body = mock_client.call_args.kwargs model=model,
max_completion_tokens=10,
messages=[{"role": "user", "content": "Hello!"}],
)
print("request_body: ", request_body) assert mock_request.called
request_body = json.loads(mock_request.calls[0].request.content)
assert request_body["model"] == model print("request_body: ", request_body)
assert request_body["max_completion_tokens"] == 10
assert request_body["messages"] == [{"role": "user", "content": "Hello!"}] assert request_body == {
"model": model,
"max_completion_tokens": 10,
"messages": [{"role": "user", "content": "Hello!"}],
}
print(f"response: {response}")
assert isinstance(response, ModelResponse)
def test_litellm_responses(): def test_litellm_responses():

View file

@ -2,7 +2,7 @@ import json
import os import os
import sys import sys
from datetime import datetime from datetime import datetime
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
@ -63,7 +63,8 @@ def test_openai_prediction_param():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_openai_prediction_param_mock(): @pytest.mark.respx
async def test_openai_prediction_param_mock(respx_mock: MockRouter):
""" """
Tests that prediction parameter is correctly passed to the API Tests that prediction parameter is correctly passed to the API
""" """
@ -91,36 +92,60 @@ async def test_openai_prediction_param_mock():
public string Username { get; set; } public string Username { get; set; }
} }
""" """
from openai import AsyncOpenAI
client = AsyncOpenAI(api_key="fake-api-key") mock_response = ModelResponse(
id="chatcmpl-AQ5RmV8GvVSRxEcDxnuXlQnsibiY9",
with patch.object( choices=[
client.chat.completions.with_raw_response, "create" Choices(
) as mock_client: message=Message(
try: content=code.replace("Username", "Email").replace(
await litellm.acompletion( "username", "email"
model="gpt-4o-mini", ),
messages=[ role="assistant",
{ )
"role": "user",
"content": "Replace the Username property with an Email property. Respond only with code, and with no markdown formatting.",
},
{"role": "user", "content": code},
],
prediction={"type": "content", "content": code},
client=client,
) )
except Exception as e: ],
print(f"Error: {e}") created=int(datetime.now().timestamp()),
model="gpt-4o-mini-2024-07-18",
usage={
"completion_tokens": 207,
"prompt_tokens": 175,
"total_tokens": 382,
"completion_tokens_details": {
"accepted_prediction_tokens": 0,
"reasoning_tokens": 0,
"rejected_prediction_tokens": 80,
},
},
)
mock_client.assert_called_once() mock_request = respx_mock.post("https://api.openai.com/v1/chat/completions").mock(
request_body = mock_client.call_args.kwargs return_value=httpx.Response(200, json=mock_response.dict())
)
# Verify the request contains the prediction parameter completion = await litellm.acompletion(
assert "prediction" in request_body model="gpt-4o-mini",
# verify prediction is correctly sent to the API messages=[
assert request_body["prediction"] == {"type": "content", "content": code} {
"role": "user",
"content": "Replace the Username property with an Email property. Respond only with code, and with no markdown formatting.",
},
{"role": "user", "content": code},
],
prediction={"type": "content", "content": code},
)
assert mock_request.called
request_body = json.loads(mock_request.calls[0].request.content)
# Verify the request contains the prediction parameter
assert "prediction" in request_body
# verify prediction is correctly sent to the API
assert request_body["prediction"] == {"type": "content", "content": code}
# Verify the completion tokens details
assert completion.usage.completion_tokens_details.accepted_prediction_tokens == 0
assert completion.usage.completion_tokens_details.rejected_prediction_tokens == 80
@pytest.mark.asyncio @pytest.mark.asyncio
@ -198,73 +223,3 @@ async def test_openai_prediction_param_with_caching():
) )
assert completion_response_3.id != completion_response_1.id assert completion_response_3.id != completion_response_1.id
@pytest.mark.asyncio()
async def test_vision_with_custom_model():
"""
Tests that an OpenAI compatible endpoint when sent an image will receive the image in the request
"""
import base64
import requests
from openai import AsyncOpenAI
client = AsyncOpenAI(api_key="fake-api-key")
litellm.set_verbose = True
api_base = "https://my-custom.api.openai.com"
# Fetch and encode a test image
url = "https://dummyimage.com/100/100/fff&text=Test+image"
response = requests.get(url)
file_data = response.content
encoded_file = base64.b64encode(file_data).decode("utf-8")
base64_image = f"data:image/png;base64,{encoded_file}"
with patch.object(
client.chat.completions.with_raw_response, "create"
) as mock_client:
try:
response = await litellm.acompletion(
model="openai/my-custom-model",
max_tokens=10,
api_base=api_base, # use the mock api
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": base64_image},
},
],
}
],
client=client,
)
except Exception as e:
print(f"Error: {e}")
mock_client.assert_called_once()
request_body = mock_client.call_args.kwargs
print("request_body: ", request_body)
assert request_body["messages"] == [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkBAMAAACCzIhnAAAAG1BMVEURAAD///+ln5/h39/Dv79qX18uHx+If39MPz9oMSdmAAAACXBIWXMAAA7EAAAOxAGVKw4bAAABB0lEQVRYhe2SzWrEIBCAh2A0jxEs4j6GLDS9hqWmV5Flt0cJS+lRwv742DXpEjY1kOZW6HwHFZnPmVEBEARBEARB/jd0KYA/bcUYbPrRLh6amXHJ/K+ypMoyUaGthILzw0l+xI0jsO7ZcmCcm4ILd+QuVYgpHOmDmz6jBeJImdcUCmeBqQpuqRIbVmQsLCrAalrGpfoEqEogqbLTWuXCPCo+Ki1XGqgQ+jVVuhB8bOaHkvmYuzm/b0KYLWwoK58oFqi6XfxQ4Uz7d6WeKpna6ytUs5e8betMcqAv5YPC5EZB2Lm9FIn0/VP6R58+/GEY1X1egVoZ/3bt/EqF6malgSAIgiDIH+QL41409QMY0LMAAAAASUVORK5CYII="
},
},
],
},
]
assert request_body["model"] == "my-custom-model"
assert request_body["max_tokens"] == 10

View file

@ -0,0 +1,94 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import httpx
import pytest
from respx import MockRouter
import litellm
from litellm import Choices, Message, ModelResponse
@pytest.mark.asyncio()
@pytest.mark.respx
async def test_vision_with_custom_model(respx_mock: MockRouter):
"""
Tests that an OpenAI compatible endpoint when sent an image will receive the image in the request
"""
import base64
import requests
litellm.set_verbose = True
api_base = "https://my-custom.api.openai.com"
# Fetch and encode a test image
url = "https://dummyimage.com/100/100/fff&text=Test+image"
response = requests.get(url)
file_data = response.content
encoded_file = base64.b64encode(file_data).decode("utf-8")
base64_image = f"data:image/png;base64,{encoded_file}"
mock_response = ModelResponse(
id="cmpl-mock",
choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
created=int(datetime.now().timestamp()),
model="my-custom-model",
)
mock_request = respx_mock.post(f"{api_base}/chat/completions").mock(
return_value=httpx.Response(200, json=mock_response.dict())
)
response = await litellm.acompletion(
model="openai/my-custom-model",
max_tokens=10,
api_base=api_base, # use the mock api
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": base64_image},
},
],
}
],
)
assert mock_request.called
request_body = json.loads(mock_request.calls[0].request.content)
print("request_body: ", request_body)
assert request_body == {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkBAMAAACCzIhnAAAAG1BMVEURAAD///+ln5/h39/Dv79qX18uHx+If39MPz9oMSdmAAAACXBIWXMAAA7EAAAOxAGVKw4bAAABB0lEQVRYhe2SzWrEIBCAh2A0jxEs4j6GLDS9hqWmV5Flt0cJS+lRwv742DXpEjY1kOZW6HwHFZnPmVEBEARBEARB/jd0KYA/bcUYbPrRLh6amXHJ/K+ypMoyUaGthILzw0l+xI0jsO7ZcmCcm4ILd+QuVYgpHOmDmz6jBeJImdcUCmeBqQpuqRIbVmQsLCrAalrGpfoEqEogqbLTWuXCPCo+Ki1XGqgQ+jVVuhB8bOaHkvmYuzm/b0KYLWwoK58oFqi6XfxQ4Uz7d6WeKpna6ytUs5e8betMcqAv5YPC5EZB2Lm9FIn0/VP6R58+/GEY1X1egVoZ/3bt/EqF6malgSAIgiDIH+QL41409QMY0LMAAAAASUVORK5CYII="
},
},
],
}
],
"model": "my-custom-model",
"max_tokens": 10,
}
print(f"response: {response}")
assert isinstance(response, ModelResponse)

View file

@ -6,7 +6,6 @@ from unittest.mock import AsyncMock
import pytest import pytest
import httpx import httpx
from respx import MockRouter from respx import MockRouter
from unittest.mock import patch, MagicMock, AsyncMock
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
@ -69,16 +68,13 @@ def test_convert_dict_to_text_completion_response():
assert response.choices[0].logprobs.top_logprobs == [None, {",": -2.1568563}] assert response.choices[0].logprobs.top_logprobs == [None, {",": -2.1568563}]
@pytest.mark.skip(
reason="need to migrate huggingface to support httpx client being passed in"
)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.respx @pytest.mark.respx
async def test_huggingface_text_completion_logprobs(): async def test_huggingface_text_completion_logprobs(respx_mock: MockRouter):
"""Test text completion with Hugging Face, focusing on logprobs structure""" """Test text completion with Hugging Face, focusing on logprobs structure"""
litellm.set_verbose = True litellm.set_verbose = True
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
# Mock the raw response from Hugging Face
mock_response = [ mock_response = [
{ {
"generated_text": ",\n\nI have a question...", # truncated for brevity "generated_text": ",\n\nI have a question...", # truncated for brevity
@ -95,48 +91,46 @@ async def test_huggingface_text_completion_logprobs():
} }
] ]
return_val = AsyncMock() # Mock the API request
mock_request = respx_mock.post(
"https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
).mock(return_value=httpx.Response(200, json=mock_response))
return_val.json.return_value = mock_response response = await litellm.atext_completion(
model="huggingface/mistralai/Mistral-7B-v0.1",
prompt="good morning",
)
client = AsyncHTTPHandler() # Verify the request
with patch.object(client, "post", return_value=return_val) as mock_post: assert mock_request.called
response = await litellm.atext_completion( request_body = json.loads(mock_request.calls[0].request.content)
model="huggingface/mistralai/Mistral-7B-v0.1", assert request_body == {
prompt="good morning", "inputs": "good morning",
client=client, "parameters": {"details": True, "return_full_text": False},
) "stream": False,
}
# Verify the request print("response=", response)
mock_post.assert_called_once()
request_body = json.loads(mock_post.call_args.kwargs["data"])
assert request_body == {
"inputs": "good morning",
"parameters": {"details": True, "return_full_text": False},
"stream": False,
}
print("response=", response) # Verify response structure
assert isinstance(response, TextCompletionResponse)
assert response.object == "text_completion"
assert response.model == "mistralai/Mistral-7B-v0.1"
# Verify response structure # Verify logprobs structure
assert isinstance(response, TextCompletionResponse) choice = response.choices[0]
assert response.object == "text_completion" assert choice.finish_reason == "length"
assert response.model == "mistralai/Mistral-7B-v0.1" assert choice.index == 0
assert isinstance(choice.logprobs.tokens, list)
assert isinstance(choice.logprobs.token_logprobs, list)
assert isinstance(choice.logprobs.text_offset, list)
assert isinstance(choice.logprobs.top_logprobs, list)
assert choice.logprobs.tokens == [",", "\n"]
assert choice.logprobs.token_logprobs == [-1.7626953, -1.7314453]
assert choice.logprobs.text_offset == [0, 1]
assert choice.logprobs.top_logprobs == [{}, {}]
# Verify logprobs structure # Verify usage
choice = response.choices[0] assert response.usage["completion_tokens"] > 0
assert choice.finish_reason == "length" assert response.usage["prompt_tokens"] > 0
assert choice.index == 0 assert response.usage["total_tokens"] > 0
assert isinstance(choice.logprobs.tokens, list)
assert isinstance(choice.logprobs.token_logprobs, list)
assert isinstance(choice.logprobs.text_offset, list)
assert isinstance(choice.logprobs.top_logprobs, list)
assert choice.logprobs.tokens == [",", "\n"]
assert choice.logprobs.token_logprobs == [-1.7626953, -1.7314453]
assert choice.logprobs.text_offset == [0, 1]
assert choice.logprobs.top_logprobs == [{}, {}]
# Verify usage
assert response.usage["completion_tokens"] > 0
assert response.usage["prompt_tokens"] > 0
assert response.usage["total_tokens"] > 0

View file

@ -1146,21 +1146,6 @@ def test_process_gemini_image():
mime_type="image/png", file_uri="https://example.com/image.png" mime_type="image/png", file_uri="https://example.com/image.png"
) )
# Test HTTPS VIDEO URL
https_result = _process_gemini_image("https://cloud-samples-data/video/animals.mp4")
print("https_result PNG", https_result)
assert https_result["file_data"] == FileDataType(
mime_type="video/mp4", file_uri="https://cloud-samples-data/video/animals.mp4"
)
# Test HTTPS PDF URL
https_result = _process_gemini_image("https://cloud-samples-data/pdf/animals.pdf")
print("https_result PDF", https_result)
assert https_result["file_data"] == FileDataType(
mime_type="application/pdf",
file_uri="https://cloud-samples-data/pdf/animals.pdf",
)
# Test base64 image # Test base64 image
base64_image = "data:image/jpeg;base64,/9j/4AAQSkZJRg..." base64_image = "data:image/jpeg;base64,/9j/4AAQSkZJRg..."
base64_result = _process_gemini_image(base64_image) base64_result = _process_gemini_image(base64_image)

View file

@ -95,107 +95,3 @@ async def test_handle_failed_db_connection():
print("_handle_failed_db_connection_for_get_key_object got exception", exc_info) print("_handle_failed_db_connection_for_get_key_object got exception", exc_info)
assert str(exc_info.value) == "Failed to connect to DB" assert str(exc_info.value) == "Failed to connect to DB"
@pytest.mark.parametrize(
    "model, expect_to_work",
    [("openai/gpt-4o-mini", True), ("openai/gpt-4o", False)],
)
@pytest.mark.asyncio
async def test_can_key_call_model(model, expect_to_work):
    """
    If wildcard model + specific model is used, choose the specific model settings
    """
    from litellm.proxy.auth.auth_checks import can_key_call_model
    from fastapi import HTTPException

    # Two deployments: a wildcard `openai/*` in the public access group and a
    # specific `openai/gpt-4o` in a private one. The specific entry's settings
    # should take precedence over the wildcard for gpt-4o.
    llm_model_list = [
        {
            "model_name": "openai/*",
            "litellm_params": {
                "model": "openai/*",
                "api_key": "test-api-key",
            },
            "model_info": {
                "id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
                "db_model": False,
                "access_groups": ["public-openai-models"],
            },
        },
        {
            "model_name": "openai/gpt-4o",
            "litellm_params": {
                "model": "openai/gpt-4o",
                "api_key": "test-api-key",
            },
            "model_info": {
                "id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
                "db_model": False,
                "access_groups": ["private-openai-models"],
            },
        },
    ]
    proxy_router = litellm.Router(model_list=llm_model_list)
    call_kwargs = {
        "model": model,
        "llm_model_list": llm_model_list,
        "valid_token": UserAPIKeyAuth(
            models=["public-openai-models"],
        ),
        "llm_router": proxy_router,
    }

    if not expect_to_work:
        # Key is scoped to the public access group only, so calling the
        # private model must be rejected.
        with pytest.raises(Exception) as e:
            await can_key_call_model(**call_kwargs)
        print(e)
    else:
        await can_key_call_model(**call_kwargs)
@pytest.mark.parametrize(
    "model, expect_to_work",
    [("openai/gpt-4o", False), ("openai/gpt-4o-mini", True)],
)
@pytest.mark.asyncio
async def test_can_team_call_model(model, expect_to_work):
    """
    A team limited to the public access group can call the wildcard-backed
    model but not the model in the private access group.
    """
    from litellm.proxy.auth.auth_checks import model_in_access_group
    from fastapi import HTTPException

    deployments = [
        {
            "model_name": "openai/*",
            "litellm_params": {
                "model": "openai/*",
                "api_key": "test-api-key",
            },
            "model_info": {
                "id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
                "db_model": False,
                "access_groups": ["public-openai-models"],
            },
        },
        {
            "model_name": "openai/gpt-4o",
            "litellm_params": {
                "model": "openai/gpt-4o",
                "api_key": "test-api-key",
            },
            "model_info": {
                "id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
                "db_model": False,
                "access_groups": ["private-openai-models"],
            },
        },
    ]
    team_router = litellm.Router(model_list=deployments)
    is_allowed = model_in_access_group(
        model=model,
        team_models=["public-openai-models"],
        llm_router=team_router,
    )

    if expect_to_work:
        assert is_allowed
    else:
        assert not is_allowed

View file

@ -33,7 +33,7 @@ from litellm.router import Router
@pytest.mark.asyncio() @pytest.mark.asyncio()
@pytest.mark.respx() @pytest.mark.respx()
async def test_aaaaazure_tenant_id_auth(respx_mock: MockRouter): async def test_azure_tenant_id_auth(respx_mock: MockRouter):
""" """
Tests when we set tenant_id, client_id, client_secret they don't get sent with the request Tests when we set tenant_id, client_id, client_secret they don't get sent with the request

View file

@ -1,128 +1,128 @@
# #### What this tests #### #### What this tests ####
# # This adds perf testing to the router, to ensure it's never > 50ms slower than the azure-openai sdk. # This adds perf testing to the router, to ensure it's never > 50ms slower than the azure-openai sdk.
# import sys, os, time, inspect, asyncio, traceback import sys, os, time, inspect, asyncio, traceback
# from datetime import datetime from datetime import datetime
# import pytest import pytest
# sys.path.insert(0, os.path.abspath("../..")) sys.path.insert(0, os.path.abspath("../.."))
# import openai, litellm, uuid import openai, litellm, uuid
# from openai import AsyncAzureOpenAI from openai import AsyncAzureOpenAI
# client = AsyncAzureOpenAI( client = AsyncAzureOpenAI(
# api_key=os.getenv("AZURE_API_KEY"), api_key=os.getenv("AZURE_API_KEY"),
# azure_endpoint=os.getenv("AZURE_API_BASE"), # type: ignore azure_endpoint=os.getenv("AZURE_API_BASE"), # type: ignore
# api_version=os.getenv("AZURE_API_VERSION"), api_version=os.getenv("AZURE_API_VERSION"),
# ) )
# model_list = [ model_list = [
# { {
# "model_name": "azure-test", "model_name": "azure-test",
# "litellm_params": { "litellm_params": {
# "model": "azure/chatgpt-v-2", "model": "azure/chatgpt-v-2",
# "api_key": os.getenv("AZURE_API_KEY"), "api_key": os.getenv("AZURE_API_KEY"),
# "api_base": os.getenv("AZURE_API_BASE"), "api_base": os.getenv("AZURE_API_BASE"),
# "api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
# }, },
# } }
# ] ]
# router = litellm.Router(model_list=model_list) # type: ignore router = litellm.Router(model_list=model_list) # type: ignore
# async def _openai_completion(): async def _openai_completion():
# try: try:
# start_time = time.time() start_time = time.time()
# response = await client.chat.completions.create( response = await client.chat.completions.create(
# model="chatgpt-v-2", model="chatgpt-v-2",
# messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}], messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
# stream=True, stream=True,
# ) )
# time_to_first_token = None time_to_first_token = None
# first_token_ts = None first_token_ts = None
# init_chunk = None init_chunk = None
# async for chunk in response: async for chunk in response:
# if ( if (
# time_to_first_token is None time_to_first_token is None
# and len(chunk.choices) > 0 and len(chunk.choices) > 0
# and chunk.choices[0].delta.content is not None and chunk.choices[0].delta.content is not None
# ): ):
# first_token_ts = time.time() first_token_ts = time.time()
# time_to_first_token = first_token_ts - start_time time_to_first_token = first_token_ts - start_time
# init_chunk = chunk init_chunk = chunk
# end_time = time.time() end_time = time.time()
# print( print(
# "OpenAI Call: ", "OpenAI Call: ",
# init_chunk, init_chunk,
# start_time, start_time,
# first_token_ts, first_token_ts,
# time_to_first_token, time_to_first_token,
# end_time, end_time,
# ) )
# return time_to_first_token return time_to_first_token
# except Exception as e: except Exception as e:
# print(e) print(e)
# return None return None
# async def _router_completion(): async def _router_completion():
# try: try:
# start_time = time.time() start_time = time.time()
# response = await router.acompletion( response = await router.acompletion(
# model="azure-test", model="azure-test",
# messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}], messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
# stream=True, stream=True,
# ) )
# time_to_first_token = None time_to_first_token = None
# first_token_ts = None first_token_ts = None
# init_chunk = None init_chunk = None
# async for chunk in response: async for chunk in response:
# if ( if (
# time_to_first_token is None time_to_first_token is None
# and len(chunk.choices) > 0 and len(chunk.choices) > 0
# and chunk.choices[0].delta.content is not None and chunk.choices[0].delta.content is not None
# ): ):
# first_token_ts = time.time() first_token_ts = time.time()
# time_to_first_token = first_token_ts - start_time time_to_first_token = first_token_ts - start_time
# init_chunk = chunk init_chunk = chunk
# end_time = time.time() end_time = time.time()
# print( print(
# "Router Call: ", "Router Call: ",
# init_chunk, init_chunk,
# start_time, start_time,
# first_token_ts, first_token_ts,
# time_to_first_token, time_to_first_token,
# end_time - first_token_ts, end_time - first_token_ts,
# ) )
# return time_to_first_token return time_to_first_token
# except Exception as e: except Exception as e:
# print(e) print(e)
# return None return None
# async def test_azure_completion_streaming(): async def test_azure_completion_streaming():
# """ """
# Test azure streaming call - measure on time to first (non-null) token. Test azure streaming call - measure on time to first (non-null) token.
# """ """
# n = 3 # Number of concurrent tasks n = 3 # Number of concurrent tasks
# ## OPENAI AVG. TIME ## OPENAI AVG. TIME
# tasks = [_openai_completion() for _ in range(n)] tasks = [_openai_completion() for _ in range(n)]
# chat_completions = await asyncio.gather(*tasks) chat_completions = await asyncio.gather(*tasks)
# successful_completions = [c for c in chat_completions if c is not None] successful_completions = [c for c in chat_completions if c is not None]
# total_time = 0 total_time = 0
# for item in successful_completions: for item in successful_completions:
# total_time += item total_time += item
# avg_openai_time = total_time / 3 avg_openai_time = total_time / 3
# ## ROUTER AVG. TIME ## ROUTER AVG. TIME
# tasks = [_router_completion() for _ in range(n)] tasks = [_router_completion() for _ in range(n)]
# chat_completions = await asyncio.gather(*tasks) chat_completions = await asyncio.gather(*tasks)
# successful_completions = [c for c in chat_completions if c is not None] successful_completions = [c for c in chat_completions if c is not None]
# total_time = 0 total_time = 0
# for item in successful_completions: for item in successful_completions:
# total_time += item total_time += item
# avg_router_time = total_time / 3 avg_router_time = total_time / 3
# ## COMPARE ## COMPARE
# print(f"avg_router_time: {avg_router_time}; avg_openai_time: {avg_openai_time}") print(f"avg_router_time: {avg_router_time}; avg_openai_time: {avg_openai_time}")
# assert avg_router_time < avg_openai_time + 0.5 assert avg_router_time < avg_openai_time + 0.5
# # asyncio.run(test_azure_completion_streaming()) # asyncio.run(test_azure_completion_streaming())

View file

@ -1146,9 +1146,7 @@ async def test_exception_with_headers_httpx(
except litellm.RateLimitError as e: except litellm.RateLimitError as e:
exception_raised = True exception_raised = True
assert ( assert e.litellm_response_headers is not None
e.litellm_response_headers is not None
), "litellm_response_headers is None"
print("e.litellm_response_headers", e.litellm_response_headers) print("e.litellm_response_headers", e.litellm_response_headers)
assert int(e.litellm_response_headers["retry-after"]) == cooldown_time assert int(e.litellm_response_headers["retry-after"]) == cooldown_time

View file

@ -20,6 +20,7 @@ sys.path.insert(
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import litellm import litellm
from litellm import APIConnectionError, Router from litellm import APIConnectionError, Router
from unittest.mock import ANY
async def test_router_init(): async def test_router_init():
@ -213,3 +214,48 @@ def test_router_init_azure_service_principal_with_secret_with_environment_variab
# asyncio.run(test_router_init()) # asyncio.run(test_router_init())
@pytest.mark.asyncio
async def test_audio_speech_router():
    """
    Test that router uses OpenAI/Azure OpenAI Client initialized during init for litellm.aspeech
    """
    from litellm import Router

    litellm.set_verbose = True

    tts_router = Router(
        model_list=[
            {
                "model_name": "tts",
                "litellm_params": {
                    "model": "azure/azure-tts",
                    "api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
                    "api_key": os.getenv("AZURE_SWEDEN_API_KEY"),
                },
            },
        ]
    )

    # Client the router cached for this deployment at init time; aspeech is
    # expected to reuse it instead of building a new one per call.
    cached_async_client = tts_router._get_client(
        deployment=tts_router.model_list[0],
        kwargs={},
        client_type="async",
    )

    with patch("litellm.aspeech") as mock_aspeech:
        await tts_router.aspeech(
            model="tts",
            voice="alloy",
            input="the quick brown fox jumped over the lazy dogs",
        )

        print(
            "litellm.aspeech was called with kwargs = ", mock_aspeech.call_args.kwargs
        )

        # Get the actual client that was passed
        client_passed_in_request = mock_aspeech.call_args.kwargs["client"]
        assert client_passed_in_request == cached_async_client

View file

@ -212,7 +212,7 @@ async def test_bedrock_guardrail_triggered():
session, session,
"sk-1234", "sk-1234",
model="fake-openai-endpoint", model="fake-openai-endpoint",
messages=[{"role": "user", "content": "Hello do you like coffee?"}], messages=[{"role": "user", "content": f"Hello do you like coffee?"}],
guardrails=["bedrock-pre-guard"], guardrails=["bedrock-pre-guard"],
) )
pytest.fail("Should have thrown an exception") pytest.fail("Should have thrown an exception")

View file

@ -693,47 +693,3 @@ def test_personal_key_generation_check():
), ),
data=GenerateKeyRequest(), data=GenerateKeyRequest(),
) )
def test_prepare_metadata_fields():
    """
    prepare_metadata_fields should keep only the caller-supplied metadata,
    replacing any pre-existing metadata values on the key.
    """
    from litellm.proxy.management_endpoints.key_management_endpoints import (
        prepare_metadata_fields,
    )

    new_metadata = {"test": "new"}
    old_metadata = {"test": "test"}

    update_request = UpdateKeyRequest(
        key_alias=None,
        duration=None,
        models=[],
        spend=None,
        max_budget=None,
        user_id=None,
        team_id=None,
        max_parallel_requests=None,
        metadata=new_metadata,
        tpm_limit=None,
        rpm_limit=None,
        budget_duration=None,
        allowed_cache_controls=[],
        soft_budget=None,
        config={},
        permissions={},
        model_max_budget={},
        send_invite_email=None,
        model_rpm_limit=None,
        model_tpm_limit=None,
        guardrails=None,
        blocked=None,
        aliases={},
        key="sk-1qGQUJJTcljeaPfzgWRrXQ",
        tags=None,
    )

    non_default_values = prepare_metadata_fields(
        data=update_request,
        non_default_values={"metadata": new_metadata},
        existing_metadata={"tags": None, **old_metadata},
    )

    assert non_default_values == {"metadata": new_metadata}

View file

@ -1345,8 +1345,17 @@ def test_generate_and_update_key(prisma_client):
) )
current_time = datetime.now(timezone.utc) current_time = datetime.now(timezone.utc)
print(
"days between now and budget_reset_at",
(budget_reset_at - current_time).days,
)
# assert budget_reset_at is 30 days from now # assert budget_reset_at is 30 days from now
assert 31 >= (budget_reset_at - current_time).days >= 29 assert (
abs(
(budget_reset_at - current_time).total_seconds() - 30 * 24 * 60 * 60
)
<= 10
)
# cleanup - delete key # cleanup - delete key
delete_key_request = KeyRequest(keys=[generated_key]) delete_key_request = KeyRequest(keys=[generated_key])
@ -2917,6 +2926,7 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
"team": "litellm-team3", "team": "litellm-team3",
"model_tpm_limit": {"gpt-4": 100}, "model_tpm_limit": {"gpt-4": 100},
"model_rpm_limit": {"gpt-4": 2}, "model_rpm_limit": {"gpt-4": 2},
"tags": None,
} }
# Update model tpm_limit and rpm_limit # Update model tpm_limit and rpm_limit
@ -2940,6 +2950,7 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
"team": "litellm-team3", "team": "litellm-team3",
"model_tpm_limit": {"gpt-4": 200}, "model_tpm_limit": {"gpt-4": 200},
"model_rpm_limit": {"gpt-4": 3}, "model_rpm_limit": {"gpt-4": 3},
"tags": None,
} }
@ -2979,6 +2990,7 @@ async def test_generate_key_with_guardrails(prisma_client):
assert result["info"]["metadata"] == { assert result["info"]["metadata"] == {
"team": "litellm-team3", "team": "litellm-team3",
"guardrails": ["aporia-pre-call"], "guardrails": ["aporia-pre-call"],
"tags": None,
} }
# Update model tpm_limit and rpm_limit # Update model tpm_limit and rpm_limit
@ -3000,6 +3012,7 @@ async def test_generate_key_with_guardrails(prisma_client):
assert result["info"]["metadata"] == { assert result["info"]["metadata"] == {
"team": "litellm-team3", "team": "litellm-team3",
"guardrails": ["aporia-pre-call", "aporia-post-call"], "guardrails": ["aporia-pre-call", "aporia-post-call"],
"tags": None,
} }

View file

@ -444,7 +444,7 @@ def test_foward_litellm_user_info_to_backend_llm_call():
def test_update_internal_user_params(): def test_update_internal_user_params():
from litellm.proxy.management_endpoints.internal_user_endpoints import ( from litellm.proxy.management_endpoints.internal_user_endpoints import (
_update_internal_new_user_params, _update_internal_user_params,
) )
from litellm.proxy._types import NewUserRequest from litellm.proxy._types import NewUserRequest
@ -456,7 +456,7 @@ def test_update_internal_user_params():
data = NewUserRequest(user_role="internal_user", user_email="krrish3@berri.ai") data = NewUserRequest(user_role="internal_user", user_email="krrish3@berri.ai")
data_json = data.model_dump() data_json = data.model_dump()
updated_data_json = _update_internal_new_user_params(data_json, data) updated_data_json = _update_internal_user_params(data_json, data)
assert updated_data_json["models"] == litellm.default_internal_user_params["models"] assert updated_data_json["models"] == litellm.default_internal_user_params["models"]
assert ( assert (
updated_data_json["max_budget"] updated_data_json["max_budget"]
@ -530,7 +530,7 @@ def test_prepare_key_update_data():
data = UpdateKeyRequest(key="test_key", metadata=None) data = UpdateKeyRequest(key="test_key", metadata=None)
updated_data = prepare_key_update_data(data, existing_key_row) updated_data = prepare_key_update_data(data, existing_key_row)
assert updated_data["metadata"] is None assert updated_data["metadata"] == None
@pytest.mark.parametrize( @pytest.mark.parametrize(

View file

@ -300,7 +300,6 @@ async def test_key_update(metadata):
get_key=key, get_key=key,
metadata=metadata, metadata=metadata,
) )
print(f"updated_key['metadata']: {updated_key['metadata']}")
assert updated_key["metadata"] == metadata assert updated_key["metadata"] == metadata
await update_proxy_budget(session=session) # resets proxy spend await update_proxy_budget(session=session) # resets proxy spend
await chat_completion(session=session, key=key) await chat_completion(session=session, key=key)

View file

@ -114,7 +114,7 @@ async def test_spend_logs():
async def get_predict_spend_logs(session): async def get_predict_spend_logs(session):
url = "http://0.0.0.0:4000/global/predict/spend/logs" url = f"http://0.0.0.0:4000/global/predict/spend/logs"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = { data = {
"data": [ "data": [
@ -155,7 +155,6 @@ async def get_spend_report(session, start_date, end_date):
return await response.json() return await response.json()
@pytest.mark.skip(reason="datetime in ci/cd gets set weirdly")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_predicted_spend_logs(): async def test_get_predicted_spend_logs():
""" """