Ollama ssl verify = False + Spend Logs reliability fixes (#7931)

* fix(http_handler.py): support passing ssl verify dynamically and using the correct httpx client based on passed ssl verify param

Fixes https://github.com/BerriAI/litellm/issues/6499

* feat(llm_http_handler.py): support passing `ssl_verify=False` dynamically in call args

Closes https://github.com/BerriAI/litellm/issues/6499

* fix(proxy/utils.py): prevent bad logs from breaking all cost tracking + reset list regardless of success/failure

prevents malformed logs from causing all spend tracking to break since they're constantly retried

* test(test_proxy_utils.py): add test to ensure bad log is dropped

* test(test_proxy_utils.py): ensure in-memory spend logs reset after bad log error

* test(test_user_api_key_auth.py): add unit test to ensure end user id as str works

* fix(auth_utils.py): ensure extracted end user id is always a str

prevents db cost tracking errors

* test(test_auth_utils.py): ensure get end user id from request body always returns a string

* test: update tests

* test: skip bedrock test- behaviour now supported

* test: fix testing

* refactor(spend_tracking_utils.py): reduce size of get_logging_payload

* test: fix test

* bump: version 1.59.4 → 1.59.5

* Revert "bump: version 1.59.4 → 1.59.5"

This reverts commit 1182b46b2e.

* fix(utils.py): fix spend logs retry logic

* fix(spend_tracking_utils.py): fix get tags

* fix(spend_tracking_utils.py): fix end user id spend tracking on pass-through endpoints
This commit is contained in:
Krish Dholakia 2025-01-23 23:05:41 -08:00 committed by GitHub
parent b94f60632a
commit e6e4da75d7
17 changed files with 406 additions and 187 deletions

View file

@ -93,11 +93,15 @@ class AsyncHTTPHandler:
event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]] = None, event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]] = None,
concurrent_limit=1000, concurrent_limit=1000,
client_alias: Optional[str] = None, # name for client in logs client_alias: Optional[str] = None, # name for client in logs
ssl_verify: Optional[Union[bool, str]] = None,
): ):
self.timeout = timeout self.timeout = timeout
self.event_hooks = event_hooks self.event_hooks = event_hooks
self.client = self.create_client( self.client = self.create_client(
timeout=timeout, concurrent_limit=concurrent_limit, event_hooks=event_hooks timeout=timeout,
concurrent_limit=concurrent_limit,
event_hooks=event_hooks,
ssl_verify=ssl_verify,
) )
self.client_alias = client_alias self.client_alias = client_alias
@ -106,10 +110,12 @@ class AsyncHTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]], timeout: Optional[Union[float, httpx.Timeout]],
concurrent_limit: int, concurrent_limit: int,
event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]], event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]],
ssl_verify: Optional[Union[bool, str]] = None,
) -> httpx.AsyncClient: ) -> httpx.AsyncClient:
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts. # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
# /path/to/certificate.pem # /path/to/certificate.pem
if ssl_verify is None:
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify) ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
# An SSL certificate used by the requested host to authenticate the client. # An SSL certificate used by the requested host to authenticate the client.
# /path/to/client.pem # /path/to/client.pem
@ -440,13 +446,17 @@ class HTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]] = None, timeout: Optional[Union[float, httpx.Timeout]] = None,
concurrent_limit=1000, concurrent_limit=1000,
client: Optional[httpx.Client] = None, client: Optional[httpx.Client] = None,
ssl_verify: Optional[Union[bool, str]] = None,
): ):
if timeout is None: if timeout is None:
timeout = _DEFAULT_TIMEOUT timeout = _DEFAULT_TIMEOUT
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts. # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
# /path/to/certificate.pem # /path/to/certificate.pem
if ssl_verify is None:
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify) ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
# An SSL certificate used by the requested host to authenticate the client. # An SSL certificate used by the requested host to authenticate the client.
# /path/to/client.pem # /path/to/client.pem
cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate) cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)
@ -506,7 +516,15 @@ class HTTPHandler:
try: try:
if timeout is not None: if timeout is not None:
req = self.client.build_request( req = self.client.build_request(
"POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout, files=files, content=content # type: ignore "POST",
url,
data=data, # type: ignore
json=json,
params=params,
headers=headers,
timeout=timeout,
files=files,
content=content, # type: ignore
) )
else: else:
req = self.client.build_request( req = self.client.build_request(
@ -660,6 +678,7 @@ def get_async_httpx_client(
_new_client = AsyncHTTPHandler( _new_client = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0) timeout=httpx.Timeout(timeout=600.0, connect=5.0)
) )
litellm.in_memory_llm_clients_cache.set_cache( litellm.in_memory_llm_clients_cache.set_cache(
key=_cache_key_name, key=_cache_key_name,
value=_new_client, value=_new_client,
@ -684,6 +703,7 @@ def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler:
pass pass
_cache_key_name = "httpx_client" + _params_key_name _cache_key_name = "httpx_client" + _params_key_name
_cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key_name) _cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key_name)
if _cached_client: if _cached_client:
return _cached_client return _cached_client

View file

@ -1,4 +1,4 @@
from typing import Optional from typing import Optional, Union
import httpx import httpx
@ -36,13 +36,13 @@ class HTTPHandler:
async def post( async def post(
self, self,
url: str, url: str,
data: Optional[dict] = None, data: Optional[Union[dict, str]] = None,
params: Optional[dict] = None, params: Optional[dict] = None,
headers: Optional[dict] = None, headers: Optional[dict] = None,
): ):
try: try:
response = await self.client.post( response = await self.client.post(
url, data=data, params=params, headers=headers url, data=data, params=params, headers=headers # type: ignore
) )
return response return response
except Exception as e: except Exception as e:

View file

@ -158,7 +158,8 @@ class BaseLLMHTTPHandler:
): ):
if client is None: if client is None:
async_httpx_client = get_async_httpx_client( async_httpx_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders(custom_llm_provider) llm_provider=litellm.LlmProviders(custom_llm_provider),
params={"ssl_verify": litellm_params.get("ssl_verify", None)},
) )
else: else:
async_httpx_client = client async_httpx_client = client
@ -318,7 +319,9 @@ class BaseLLMHTTPHandler:
) )
if client is None or not isinstance(client, HTTPHandler): if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client() sync_httpx_client = _get_httpx_client(
params={"ssl_verify": litellm_params.get("ssl_verify", None)}
)
else: else:
sync_httpx_client = client sync_httpx_client = client
@ -359,7 +362,11 @@ class BaseLLMHTTPHandler:
client: Optional[HTTPHandler] = None, client: Optional[HTTPHandler] = None,
) -> Tuple[Any, dict]: ) -> Tuple[Any, dict]:
if client is None or not isinstance(client, HTTPHandler): if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client() sync_httpx_client = _get_httpx_client(
{
"ssl_verify": litellm_params.get("ssl_verify", None),
}
)
else: else:
sync_httpx_client = client sync_httpx_client = client
stream = True stream = True
@ -411,7 +418,7 @@ class BaseLLMHTTPHandler:
fake_stream: bool = False, fake_stream: bool = False,
client: Optional[AsyncHTTPHandler] = None, client: Optional[AsyncHTTPHandler] = None,
): ):
completion_stream, _response_headers = await self.make_async_call( completion_stream, _response_headers = await self.make_async_call_stream_helper(
custom_llm_provider=custom_llm_provider, custom_llm_provider=custom_llm_provider,
provider_config=provider_config, provider_config=provider_config,
api_base=api_base, api_base=api_base,
@ -432,7 +439,7 @@ class BaseLLMHTTPHandler:
) )
return streamwrapper return streamwrapper
async def make_async_call( async def make_async_call_stream_helper(
self, self,
custom_llm_provider: str, custom_llm_provider: str,
provider_config: BaseConfig, provider_config: BaseConfig,
@ -446,9 +453,15 @@ class BaseLLMHTTPHandler:
fake_stream: bool = False, fake_stream: bool = False,
client: Optional[AsyncHTTPHandler] = None, client: Optional[AsyncHTTPHandler] = None,
) -> Tuple[Any, httpx.Headers]: ) -> Tuple[Any, httpx.Headers]:
"""
Helper function for making an async call with stream.
Handles fake stream as well.
"""
if client is None: if client is None:
async_httpx_client = get_async_httpx_client( async_httpx_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders(custom_llm_provider) llm_provider=litellm.LlmProviders(custom_llm_provider),
params={"ssl_verify": litellm_params.get("ssl_verify", None)},
) )
else: else:
async_httpx_client = client async_httpx_client = client

View file

@ -855,6 +855,8 @@ def completion( # type: ignore # noqa: PLR0915
cooldown_time = kwargs.get("cooldown_time", None) cooldown_time = kwargs.get("cooldown_time", None)
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None) context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
organization = kwargs.get("organization", None) organization = kwargs.get("organization", None)
### VERIFY SSL ###
ssl_verify = kwargs.get("ssl_verify", None)
### CUSTOM MODEL COST ### ### CUSTOM MODEL COST ###
input_cost_per_token = kwargs.get("input_cost_per_token", None) input_cost_per_token = kwargs.get("input_cost_per_token", None)
output_cost_per_token = kwargs.get("output_cost_per_token", None) output_cost_per_token = kwargs.get("output_cost_per_token", None)
@ -1102,6 +1104,7 @@ def completion( # type: ignore # noqa: PLR0915
drop_params=kwargs.get("drop_params"), drop_params=kwargs.get("drop_params"),
prompt_id=prompt_id, prompt_id=prompt_id,
prompt_variables=prompt_variables, prompt_variables=prompt_variables,
ssl_verify=ssl_verify,
) )
logging.update_environment_variables( logging.update_environment_variables(
model=model, model=model,

View file

@ -0,0 +1,14 @@
model_list:
- model_name: bedrock-claude
litellm_params:
model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
aws_region_name: us-east-1
aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
litellm_settings:
callbacks: ["datadog"] # logs llm success + failure logs on datadog
service_callback: ["datadog"] # logs redis, postgres failures on datadog
general_settings:
store_prompts_in_spend_logs: true

View file

@ -11,7 +11,3 @@ model_list:
api_base: http://0.0.0.0:8090 api_base: http://0.0.0.0:8090
timeout: 2 timeout: 2
num_retries: 0 num_retries: 0
litellm_settings:
success_callback: ["langfuse"]

View file

@ -1483,7 +1483,8 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
# Check if the value is None and set the corresponding attribute # Check if the value is None and set the corresponding attribute
if getattr(self, attr_name, None) is None: if getattr(self, attr_name, None) is None:
kwargs[attr_name] = value kwargs[attr_name] = value
if key == "end_user_id" and value is not None and isinstance(value, int):
kwargs[key] = str(value)
# Initialize the superclass # Initialize the superclass
super().__init__(**kwargs) super().__init__(**kwargs)

View file

@ -498,13 +498,13 @@ def _has_user_setup_sso():
def get_end_user_id_from_request_body(request_body: dict) -> Optional[str]: def get_end_user_id_from_request_body(request_body: dict) -> Optional[str]:
# openai - check 'user' # openai - check 'user'
if "user" in request_body: if "user" in request_body and request_body["user"] is not None:
return request_body["user"] return str(request_body["user"])
# anthropic - check 'litellm_metadata' # anthropic - check 'litellm_metadata'
end_user_id = request_body.get("litellm_metadata", {}).get("user", None) end_user_id = request_body.get("litellm_metadata", {}).get("user", None)
if end_user_id: if end_user_id:
return end_user_id return str(end_user_id)
metadata = request_body.get("metadata") metadata = request_body.get("metadata")
if metadata and "user_id" in metadata: if metadata and "user_id" in metadata and metadata["user_id"] is not None:
return metadata["user_id"] return str(metadata["user_id"])
return None return None

View file

@ -1051,7 +1051,6 @@ async def update_database( # noqa: PLR0915
response_obj=completion_response, response_obj=completion_response,
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
end_user_id=end_user_id,
) )
payload["spend"] = response_cost payload["spend"] = response_cost

View file

@ -10,6 +10,7 @@ from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
from litellm.proxy.utils import PrismaClient, hash_token from litellm.proxy.utils import PrismaClient, hash_token
from litellm.types.utils import StandardLoggingPayload from litellm.types.utils import StandardLoggingPayload
from litellm.utils import get_end_user_id_for_cost_tracking
def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool: def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
@ -29,15 +30,36 @@ def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
return False return False
def get_logging_payload( def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
kwargs, response_obj, start_time, end_time, end_user_id: Optional[str] if metadata is None:
) -> SpendLogsPayload: return SpendLogsMetadata(
user_api_key=None,
from litellm.proxy.proxy_server import general_settings, master_key user_api_key_alias=None,
user_api_key_team_id=None,
verbose_proxy_logger.debug( user_api_key_user_id=None,
f"SpendTable: get_logging_payload - kwargs: {kwargs}\n\n" user_api_key_team_alias=None,
spend_logs_metadata=None,
requester_ip_address=None,
additional_usage_values=None,
) )
verbose_proxy_logger.debug(
"getting payload for SpendLogs, available keys in metadata: "
+ str(list(metadata.keys()))
)
# Filter the metadata dictionary to include only the specified keys
clean_metadata = SpendLogsMetadata(
**{ # type: ignore
key: metadata[key]
for key in SpendLogsMetadata.__annotations__.keys()
if key in metadata
}
)
return clean_metadata
def get_logging_payload(kwargs, response_obj, start_time, end_time) -> SpendLogsPayload:
from litellm.proxy.proxy_server import general_settings, master_key
if kwargs is None: if kwargs is None:
kwargs = {} kwargs = {}
@ -57,14 +79,23 @@ def get_logging_payload(
if isinstance(usage, litellm.Usage): if isinstance(usage, litellm.Usage):
usage = dict(usage) usage = dict(usage)
id = cast(dict, response_obj).get("id") or kwargs.get("litellm_call_id") id = cast(dict, response_obj).get("id") or kwargs.get("litellm_call_id")
api_key = metadata.get("user_api_key", "") standard_logging_payload = cast(
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( Optional[StandardLoggingPayload], kwargs.get("standard_logging_object", None)
"standard_logging_object", None )
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
if standard_logging_payload is not None:
api_key = standard_logging_payload["metadata"].get("user_api_key_hash") or ""
end_user_id = end_user_id or standard_logging_payload["metadata"].get(
"user_api_key_end_user_id"
)
else:
api_key = ""
request_tags = (
json.dumps(metadata.get("tags", []))
if isinstance(metadata.get("tags", []), list)
else "[]"
) )
if api_key is not None and isinstance(api_key, str):
if api_key.startswith("sk-"):
# hash the api_key
api_key = hash_token(api_key)
if ( if (
_is_master_key(api_key=api_key, _master_key=master_key) _is_master_key(api_key=api_key, _master_key=master_key)
and general_settings.get("disable_adding_master_key_hash_to_db") is True and general_settings.get("disable_adding_master_key_hash_to_db") is True
@ -74,37 +105,8 @@ def get_logging_payload(
_model_id = metadata.get("model_info", {}).get("id", "") _model_id = metadata.get("model_info", {}).get("id", "")
_model_group = metadata.get("model_group", "") _model_group = metadata.get("model_group", "")
request_tags = (
json.dumps(metadata.get("tags", []))
if isinstance(metadata.get("tags", []), list)
else "[]"
)
# clean up litellm metadata # clean up litellm metadata
clean_metadata = SpendLogsMetadata( clean_metadata = _get_spend_logs_metadata(metadata)
user_api_key=None,
user_api_key_alias=None,
user_api_key_team_id=None,
user_api_key_user_id=None,
user_api_key_team_alias=None,
spend_logs_metadata=None,
requester_ip_address=None,
additional_usage_values=None,
)
if isinstance(metadata, dict):
verbose_proxy_logger.debug(
"getting payload for SpendLogs, available keys in metadata: "
+ str(list(metadata.keys()))
)
# Filter the metadata dictionary to include only the specified keys
clean_metadata = SpendLogsMetadata(
**{ # type: ignore
key: metadata[key]
for key in SpendLogsMetadata.__annotations__.keys()
if key in metadata
}
)
special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"] special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
additional_usage_values = {} additional_usage_values = {}

View file

@ -2432,6 +2432,126 @@ async def reset_budget(prisma_client: PrismaClient):
) )
class ProxyUpdateSpend:
@staticmethod
async def update_end_user_spend(
n_retry_times: int, prisma_client: PrismaClient, proxy_logging_obj: ProxyLogging
):
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
) as transaction:
async with transaction.batch_() as batcher:
for (
end_user_id,
response_cost,
) in prisma_client.end_user_list_transactons.items():
if litellm.max_end_user_budget is not None:
pass
batcher.litellm_endusertable.upsert(
where={"user_id": end_user_id},
data={
"create": {
"user_id": end_user_id,
"spend": response_cost,
"blocked": False,
},
"update": {"spend": {"increment": response_cost}},
},
)
break
except DB_CONNECTION_ERROR_TYPES as e:
if i >= n_retry_times: # If we've reached the maximum number of retries
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
finally:
prisma_client.end_user_list_transactons = (
{}
) # reset the end user list transactions - prevent bad data from causing issues
@staticmethod
async def update_spend_logs(
n_retry_times: int,
prisma_client: PrismaClient,
db_writer_client: Optional[HTTPHandler],
proxy_logging_obj: ProxyLogging,
):
BATCH_SIZE = 100 # Preferred size of each batch to write to the database
MAX_LOGS_PER_INTERVAL = (
1000 # Maximum number of logs to flush in a single interval
)
# Get initial logs to process
logs_to_process = prisma_client.spend_log_transactions[:MAX_LOGS_PER_INTERVAL]
start_time = time.time()
try:
for i in range(n_retry_times + 1):
try:
base_url = os.getenv("SPEND_LOGS_URL", None)
if (
len(logs_to_process) > 0
and base_url is not None
and db_writer_client is not None
):
if not base_url.endswith("/"):
base_url += "/"
verbose_proxy_logger.debug("base_url: {}".format(base_url))
response = await db_writer_client.post(
url=base_url + "spend/update",
data=json.dumps(logs_to_process),
headers={"Content-Type": "application/json"},
)
if response.status_code == 200:
prisma_client.spend_log_transactions = (
prisma_client.spend_log_transactions[
len(logs_to_process) :
]
)
else:
for j in range(0, len(logs_to_process), BATCH_SIZE):
batch = logs_to_process[j : j + BATCH_SIZE]
batch_with_dates = [
prisma_client.jsonify_object({**entry})
for entry in batch
]
await prisma_client.db.litellm_spendlogs.create_many(
data=batch_with_dates, skip_duplicates=True
)
verbose_proxy_logger.debug(
f"Flushed {len(batch)} logs to the DB."
)
prisma_client.spend_log_transactions = (
prisma_client.spend_log_transactions[len(logs_to_process) :]
)
verbose_proxy_logger.debug(
f"{len(logs_to_process)} logs processed. Remaining in queue: {len(prisma_client.spend_log_transactions)}"
)
break
except DB_CONNECTION_ERROR_TYPES:
if i is None:
i = 0
if i >= n_retry_times:
raise
await asyncio.sleep(2**i)
except Exception as e:
prisma_client.spend_log_transactions = prisma_client.spend_log_transactions[
len(logs_to_process) :
]
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
async def update_spend( # noqa: PLR0915 async def update_spend( # noqa: PLR0915
prisma_client: PrismaClient, prisma_client: PrismaClient,
db_writer_client: Optional[HTTPHandler], db_writer_client: Optional[HTTPHandler],
@ -2490,47 +2610,11 @@ async def update_spend( # noqa: PLR0915
) )
) )
if len(prisma_client.end_user_list_transactons.keys()) > 0: if len(prisma_client.end_user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1): await ProxyUpdateSpend.update_end_user_spend(
start_time = time.time() n_retry_times=n_retry_times,
try: prisma_client=prisma_client,
async with prisma_client.db.tx( proxy_logging_obj=proxy_logging_obj,
timeout=timedelta(seconds=60)
) as transaction:
async with transaction.batch_() as batcher:
for (
end_user_id,
response_cost,
) in prisma_client.end_user_list_transactons.items():
if litellm.max_end_user_budget is not None:
pass
batcher.litellm_endusertable.upsert(
where={"user_id": end_user_id},
data={
"create": {
"user_id": end_user_id,
"spend": response_cost,
"blocked": False,
},
"update": {"spend": {"increment": response_cost}},
},
) )
prisma_client.end_user_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
break
except DB_CONNECTION_ERROR_TYPES as e:
if i >= n_retry_times: # If we've reached the maximum number of retries
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
### UPDATE KEY TABLE ### ### UPDATE KEY TABLE ###
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"KEY Spend transactions: {}".format( "KEY Spend transactions: {}".format(
@ -2687,79 +2771,12 @@ async def update_spend( # noqa: PLR0915
"Spend Logs transactions: {}".format(len(prisma_client.spend_log_transactions)) "Spend Logs transactions: {}".format(len(prisma_client.spend_log_transactions))
) )
BATCH_SIZE = 100 # Preferred size of each batch to write to the database
MAX_LOGS_PER_INTERVAL = 1000 # Maximum number of logs to flush in a single interval
if len(prisma_client.spend_log_transactions) > 0: if len(prisma_client.spend_log_transactions) > 0:
for i in range(n_retry_times + 1): await ProxyUpdateSpend.update_spend_logs(
start_time = time.time() n_retry_times=n_retry_times,
try: prisma_client=prisma_client,
base_url = os.getenv("SPEND_LOGS_URL", None) proxy_logging_obj=proxy_logging_obj,
## WRITE TO SEPARATE SERVER ## db_writer_client=db_writer_client,
if (
len(prisma_client.spend_log_transactions) > 0
and base_url is not None
and db_writer_client is not None
):
if not base_url.endswith("/"):
base_url += "/"
verbose_proxy_logger.debug("base_url: {}".format(base_url))
response = await db_writer_client.post(
url=base_url + "spend/update",
data=json.dumps(prisma_client.spend_log_transactions), # type: ignore
headers={"Content-Type": "application/json"},
)
if response.status_code == 200:
prisma_client.spend_log_transactions = []
else: ## (default) WRITE TO DB ##
logs_to_process = prisma_client.spend_log_transactions[
:MAX_LOGS_PER_INTERVAL
]
for j in range(0, len(logs_to_process), BATCH_SIZE):
# Create sublist for current batch, ensuring it doesn't exceed the BATCH_SIZE
batch = logs_to_process[j : j + BATCH_SIZE]
# Convert datetime strings to Date objects
batch_with_dates = [
prisma_client.jsonify_object(
{
**entry,
}
)
for entry in batch
]
await prisma_client.db.litellm_spendlogs.create_many(
data=batch_with_dates, skip_duplicates=True # type: ignore
)
verbose_proxy_logger.debug(
f"Flushed {len(batch)} logs to the DB."
)
# Remove the processed logs from spend_logs
prisma_client.spend_log_transactions = (
prisma_client.spend_log_transactions[len(logs_to_process) :]
)
verbose_proxy_logger.debug(
f"{len(logs_to_process)} logs processed. Remaining in queue: {len(prisma_client.spend_log_transactions)}"
)
break
except DB_CONNECTION_ERROR_TYPES as e:
if i is None:
i = 0
if (
i >= n_retry_times
): # If we've reached the maximum number of retries raise the exception
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # type: ignore
except Exception as e:
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
) )

View file

@ -2132,6 +2132,7 @@ def get_litellm_params(
prompt_id: Optional[str] = None, prompt_id: Optional[str] = None,
prompt_variables: Optional[dict] = None, prompt_variables: Optional[dict] = None,
async_call: Optional[bool] = None, async_call: Optional[bool] = None,
ssl_verify: Optional[bool] = None,
**kwargs, **kwargs,
) -> dict: ) -> dict:
litellm_params = { litellm_params = {
@ -2170,6 +2171,7 @@ def get_litellm_params(
"prompt_id": prompt_id, "prompt_id": prompt_id,
"prompt_variables": prompt_variables, "prompt_variables": prompt_variables,
"async_call": async_call, "async_call": async_call,
"ssl_verify": ssl_verify,
} }
return litellm_params return litellm_params

View file

@ -68,3 +68,12 @@ def test_configurable_clientside_parameters(
) )
print(resp) print(resp)
assert resp == should_return_true assert resp == should_return_true
def test_get_end_user_id_from_request_body_always_returns_str():
from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body
request_body = {"user": 123}
end_user_id = get_end_user_id_from_request_body(request_body)
assert end_user_id == "123"
assert isinstance(end_user_id, str)

View file

@ -174,3 +174,67 @@ def test_ollama_chat_function_calling():
print(json.loads(tool_calls[0].function.arguments)) print(json.loads(tool_calls[0].function.arguments))
print(response) print(response)
def test_ollama_ssl_verify():
from litellm.llms.custom_httpx.http_handler import HTTPHandler
import ssl
import httpx
try:
response = litellm.completion(
model="ollama/llama3.1",
messages=[
{
"role": "user",
"content": "What's the weather like in San Francisco?",
}
],
ssl_verify=False,
)
except Exception as e:
print(e)
client: HTTPHandler = litellm.in_memory_llm_clients_cache.get_cache(
"httpx_clientssl_verify_False"
)
test_client = httpx.Client(verify=False)
print(client)
assert (
client.client._transport._pool._ssl_context.verify_mode
== test_client._transport._pool._ssl_context.verify_mode
)
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.asyncio
async def test_async_ollama_ssl_verify(stream):
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
import httpx
try:
response = await litellm.acompletion(
model="ollama/llama3.1",
messages=[
{
"role": "user",
"content": "What's the weather like in San Francisco?",
}
],
ssl_verify=False,
stream=stream,
)
except Exception as e:
print(e)
client: AsyncHTTPHandler = litellm.in_memory_llm_clients_cache.get_cache(
"async_httpx_clientssl_verify_Falseollama"
)
test_client = httpx.AsyncClient(verify=False)
print(client)
assert (
client.client._transport._pool._ssl_context.verify_mode
== test_client._transport._pool._ssl_context.verify_mode
)

View file

@ -170,6 +170,12 @@ def test_spend_logs_payload(model_id: Optional[str]):
"end_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 954146), "end_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 954146),
"cache_hit": None, "cache_hit": None,
"response_cost": 2.4999999999999998e-05, "response_cost": 2.4999999999999998e-05,
"standard_logging_object": {
"request_tags": ["model-anthropic-claude-v2.1", "app-ishaan-prod"],
"metadata": {
"user_api_key_end_user_id": "test-user",
},
},
}, },
"response_obj": litellm.ModelResponse( "response_obj": litellm.ModelResponse(
id=model_id, id=model_id,
@ -192,7 +198,6 @@ def test_spend_logs_payload(model_id: Optional[str]):
), ),
"start_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 308604), "start_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 308604),
"end_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 954146), "end_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 954146),
"end_user_id": None,
} }
payload: SpendLogsPayload = get_logging_payload(**input_args) payload: SpendLogsPayload = get_logging_payload(**input_args)
@ -229,6 +234,7 @@ def test_spend_logs_payload_whisper():
"metadata": { "metadata": {
"user_api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b", "user_api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
"user_api_key_alias": None, "user_api_key_alias": None,
"user_api_key_end_user_id": "test-user",
"user_api_end_user_max_budget": None, "user_api_end_user_max_budget": None,
"litellm_api_version": "1.40.19", "litellm_api_version": "1.40.19",
"global_max_parallel_requests": None, "global_max_parallel_requests": None,
@ -293,7 +299,6 @@ def test_spend_logs_payload_whisper():
response_obj=response, response_obj=response,
start_time=datetime.datetime.now(), start_time=datetime.datetime.now(),
end_time=datetime.datetime.now(), end_time=datetime.datetime.now(),
end_user_id="test-user",
) )
print("payload: ", payload) print("payload: ", payload)
@ -335,13 +340,16 @@ def test_spend_logs_payload_with_prompts_enabled(monkeypatch):
), ),
"start_time": datetime.datetime.now(), "start_time": datetime.datetime.now(),
"end_time": datetime.datetime.now(), "end_time": datetime.datetime.now(),
"end_user_id": "user123",
} }
# Create a standard logging payload # Create a standard logging payload
standard_logging_payload = { standard_logging_payload = {
"messages": [{"role": "user", "content": "Hello!"}], "messages": [{"role": "user", "content": "Hello!"}],
"response": {"role": "assistant", "content": "Hi there!"}, "response": {"role": "assistant", "content": "Hi there!"},
"metadata": {
"user_api_key_end_user_id": "test-user",
},
"request_tags": ["model-anthropic-claude-v2.1", "app-ishaan-prod"],
} }
input_args["kwargs"]["standard_logging_object"] = standard_logging_payload input_args["kwargs"]["standard_logging_object"] = standard_logging_payload

View file

@ -1497,6 +1497,62 @@ def test_custom_openapi(mock_get_openapi_schema):
assert openapi_schema is not None assert openapi_schema is not None
import pytest
from unittest.mock import MagicMock, AsyncMock
import asyncio
from datetime import timedelta
from litellm.proxy.utils import ProxyUpdateSpend
@pytest.mark.asyncio
async def test_end_user_transactions_reset():
# Setup
mock_client = MagicMock()
mock_client.end_user_list_transactons = {"1": 10.0} # Bad log
mock_client.db.tx = AsyncMock(side_effect=Exception("DB Error"))
# Call function - should raise error
with pytest.raises(Exception):
await ProxyUpdateSpend.update_end_user_spend(
n_retry_times=0, prisma_client=mock_client, proxy_logging_obj=MagicMock()
)
# Verify cleanup happened
assert (
mock_client.end_user_list_transactons == {}
), "Transactions list should be empty after error"
@pytest.mark.asyncio
async def test_spend_logs_cleanup_after_error():
# Setup test data
mock_client = MagicMock()
mock_client.spend_log_transactions = [
{"id": 1, "amount": 10.0},
{"id": 2, "amount": 20.0},
{"id": 3, "amount": 30.0},
]
# Make the DB operation fail
mock_client.db.litellm_spendlogs.create_many = AsyncMock(
side_effect=Exception("DB Error")
)
original_logs = mock_client.spend_log_transactions.copy()
# Call function - should raise error
with pytest.raises(Exception):
await ProxyUpdateSpend.update_spend_logs(
n_retry_times=0,
prisma_client=mock_client,
db_writer_client=None, # Test DB write path
proxy_logging_obj=MagicMock(),
)
# Verify the first batch was removed from spend_log_transactions
assert (
mock_client.spend_log_transactions == original_logs[100:]
), "Should remove processed logs even after error"
def test_provider_specific_header(): def test_provider_specific_header():
from litellm.proxy.litellm_pre_call_utils import ( from litellm.proxy.litellm_pre_call_utils import (
add_provider_specific_headers_to_request, add_provider_specific_headers_to_request,

View file

@ -862,3 +862,18 @@ async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypa
await _jwt_auth_user_api_key_auth_builder(**args) await _jwt_auth_user_api_key_auth_builder(**args)
else: else:
await _jwt_auth_user_api_key_auth_builder(**args) await _jwt_auth_user_api_key_auth_builder(**args)
def test_user_api_key_auth_end_user_str():
from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth
user_api_key_args = {
"api_key": "sk-1234",
"parent_otel_span": None,
"user_role": LitellmUserRoles.PROXY_ADMIN,
"end_user_id": "1",
"user_id": "default_user_id",
}
user_api_key_auth = UserAPIKeyAuth(**user_api_key_args)
assert user_api_key_auth.end_user_id == "1"