Ollama `ssl_verify=False` + Spend Logs reliability fixes (#7931)

* fix(http_handler.py): support passing `ssl_verify` dynamically and using the correct httpx client based on the passed `ssl_verify` param

Fixes https://github.com/BerriAI/litellm/issues/6499

* feat(llm_http_handler.py): support passing `ssl_verify=False` dynamically in call args (usage sketch after this list)

Closes https://github.com/BerriAI/litellm/issues/6499

* fix(proxy/utils.py): prevent bad logs from breaking all cost tracking + reset the in-memory list regardless of success/failure

prevents malformed logs from breaking all spend tracking, since failed writes are otherwise retried indefinitely (a simplified sketch of this pattern follows the proxy/utils.py diff below)

* test(test_proxy_utils.py): add test to ensure bad log is dropped

* test(test_proxy_utils.py): ensure in-memory spend logs reset after bad log error

* test(test_user_api_key_auth.py): add unit test to ensure an end user id passed as a str works

* fix(auth_utils.py): ensure extracted end user id is always a str

prevents db cost tracking errors

* test(test_auth_utils.py): ensure get end user id from request body always returns a string

* test: update tests

* test: skip bedrock test; behaviour is now supported

* test: fix testing

* refactor(spend_tracking_utils.py): reduce size of get_logging_payload

* test: fix test

* bump: version 1.59.4 → 1.59.5

* Revert "bump: version 1.59.4 → 1.59.5"

This reverts commit 1182b46b2e.

* fix(utils.py): fix spend logs retry logic

* fix(spend_tracking_utils.py): fix get tags

* fix(spend_tracking_utils.py): fix end user id spend tracking on pass-through endpoints
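
Usage sketch for the `ssl_verify` flag described above (adapted from the tests added in this PR; assumes a locally reachable Ollama server):

    import litellm

    # Per-call override: a dedicated httpx client is built (and cached)
    # with verify=False instead of the process-wide SSL settings.
    response = litellm.completion(
        model="ollama/llama3.1",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        ssl_verify=False,
    )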
Krish Dholakia 2025-01-23 23:05:41 -08:00 committed by GitHub
parent 851b0c4c4d
commit 1e011b66d3
17 changed files with 406 additions and 187 deletions


@@ -93,11 +93,15 @@ class AsyncHTTPHandler:
event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]] = None,
concurrent_limit=1000,
client_alias: Optional[str] = None, # name for client in logs
ssl_verify: Optional[Union[bool, str]] = None,
):
self.timeout = timeout
self.event_hooks = event_hooks
self.client = self.create_client(
timeout=timeout, concurrent_limit=concurrent_limit, event_hooks=event_hooks
timeout=timeout,
concurrent_limit=concurrent_limit,
event_hooks=event_hooks,
ssl_verify=ssl_verify,
)
self.client_alias = client_alias
@@ -106,11 +110,13 @@ class AsyncHTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]],
concurrent_limit: int,
event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]],
ssl_verify: Optional[Union[bool, str]] = None,
) -> httpx.AsyncClient:
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
# /path/to/certificate.pem
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
if ssl_verify is None:
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
# An SSL certificate used by the requested host to authenticate the client.
# /path/to/client.pem
cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)
@@ -440,13 +446,17 @@ class HTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]] = None,
concurrent_limit=1000,
client: Optional[httpx.Client] = None,
ssl_verify: Optional[Union[bool, str]] = None,
):
if timeout is None:
timeout = _DEFAULT_TIMEOUT
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
# /path/to/certificate.pem
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
if ssl_verify is None:
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
# An SSL certificate used by the requested host to authenticate the client.
# /path/to/client.pem
cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)
@@ -506,7 +516,15 @@ class HTTPHandler:
try:
if timeout is not None:
req = self.client.build_request(
"POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout, files=files, content=content # type: ignore
"POST",
url,
data=data, # type: ignore
json=json,
params=params,
headers=headers,
timeout=timeout,
files=files,
content=content, # type: ignore
)
else:
req = self.client.build_request(
@@ -660,6 +678,7 @@ def get_async_httpx_client(
_new_client = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
litellm.in_memory_llm_clients_cache.set_cache(
key=_cache_key_name,
value=_new_client,
@@ -684,6 +703,7 @@ def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler:
pass
_cache_key_name = "httpx_client" + _params_key_name
_cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key_name)
if _cached_client:
return _cached_client
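
A minimal sketch of the new constructor parameter (per the hunks above, `ssl_verify` accepts a bool or a CA-bundle path, and when left as None the handlers still fall back to the SSL_VERIFY env var / litellm.ssl_verify):

    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler

    # Explicit per-client override; skips the SSL_VERIFY / litellm.ssl_verify fallback.
    sync_client = HTTPHandler(ssl_verify=False)

    # A CA bundle path is also accepted, mirroring httpx's `verify` argument.
    async_client = AsyncHTTPHandler(ssl_verify="/path/to/certificate.pem")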


@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union
import httpx
@@ -36,13 +36,13 @@ class HTTPHandler:
async def post(
self,
url: str,
data: Optional[dict] = None,
data: Optional[Union[dict, str]] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
):
try:
response = await self.client.post(
url, data=data, params=params, headers=headers
url, data=data, params=params, headers=headers # type: ignore
)
return response
except Exception as e:


@@ -158,7 +158,8 @@ class BaseLLMHTTPHandler:
):
if client is None:
async_httpx_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders(custom_llm_provider)
llm_provider=litellm.LlmProviders(custom_llm_provider),
params={"ssl_verify": litellm_params.get("ssl_verify", None)},
)
else:
async_httpx_client = client
@@ -318,7 +319,9 @@ class BaseLLMHTTPHandler:
)
if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client()
sync_httpx_client = _get_httpx_client(
params={"ssl_verify": litellm_params.get("ssl_verify", None)}
)
else:
sync_httpx_client = client
@@ -359,7 +362,11 @@ class BaseLLMHTTPHandler:
client: Optional[HTTPHandler] = None,
) -> Tuple[Any, dict]:
if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client()
sync_httpx_client = _get_httpx_client(
{
"ssl_verify": litellm_params.get("ssl_verify", None),
}
)
else:
sync_httpx_client = client
stream = True
@@ -411,7 +418,7 @@ class BaseLLMHTTPHandler:
fake_stream: bool = False,
client: Optional[AsyncHTTPHandler] = None,
):
completion_stream, _response_headers = await self.make_async_call(
completion_stream, _response_headers = await self.make_async_call_stream_helper(
custom_llm_provider=custom_llm_provider,
provider_config=provider_config,
api_base=api_base,
@@ -432,7 +439,7 @@ class BaseLLMHTTPHandler:
)
return streamwrapper
async def make_async_call(
async def make_async_call_stream_helper(
self,
custom_llm_provider: str,
provider_config: BaseConfig,
@@ -446,9 +453,15 @@ class BaseLLMHTTPHandler:
fake_stream: bool = False,
client: Optional[AsyncHTTPHandler] = None,
) -> Tuple[Any, httpx.Headers]:
"""
Helper function for making an async call with stream.
Handles fake stream as well.
"""
if client is None:
async_httpx_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders(custom_llm_provider)
llm_provider=litellm.LlmProviders(custom_llm_provider),
params={"ssl_verify": litellm_params.get("ssl_verify", None)},
)
else:
async_httpx_client = client


@@ -855,6 +855,8 @@ def completion(  # type: ignore # noqa: PLR0915
cooldown_time = kwargs.get("cooldown_time", None)
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
organization = kwargs.get("organization", None)
### VERIFY SSL ###
ssl_verify = kwargs.get("ssl_verify", None)
### CUSTOM MODEL COST ###
input_cost_per_token = kwargs.get("input_cost_per_token", None)
output_cost_per_token = kwargs.get("output_cost_per_token", None)
@@ -1102,6 +1104,7 @@ def completion(  # type: ignore # noqa: PLR0915
drop_params=kwargs.get("drop_params"),
prompt_id=prompt_id,
prompt_variables=prompt_variables,
ssl_verify=ssl_verify,
)
logging.update_environment_variables(
model=model,


@@ -0,0 +1,14 @@
model_list:
- model_name: bedrock-claude
litellm_params:
model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
aws_region_name: us-east-1
aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
litellm_settings:
callbacks: ["datadog"] # logs llm success + failure logs on datadog
service_callback: ["datadog"] # logs redis, postgres failures on datadog
general_settings:
store_prompts_in_spend_logs: true


@@ -11,7 +11,3 @@ model_list:
api_base: http://0.0.0.0:8090
timeout: 2
num_retries: 0
litellm_settings:
success_callback: ["langfuse"]


@@ -1483,7 +1483,8 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
# Check if the value is None and set the corresponding attribute
if getattr(self, attr_name, None) is None:
kwargs[attr_name] = value
if key == "end_user_id" and value is not None and isinstance(value, int):
kwargs[key] = str(value)
# Initialize the superclass
super().__init__(**kwargs)


@@ -498,13 +498,13 @@ def _has_user_setup_sso():
def get_end_user_id_from_request_body(request_body: dict) -> Optional[str]:
# openai - check 'user'
if "user" in request_body:
return request_body["user"]
if "user" in request_body and request_body["user"] is not None:
return str(request_body["user"])
# anthropic - check 'litellm_metadata'
end_user_id = request_body.get("litellm_metadata", {}).get("user", None)
if end_user_id:
return end_user_id
return str(end_user_id)
metadata = request_body.get("metadata")
if metadata and "user_id" in metadata:
return metadata["user_id"]
if metadata and "user_id" in metadata and metadata["user_id"] is not None:
return str(metadata["user_id"])
return None
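
The coercion can be exercised directly (this mirrors the unit test added in this PR):

    from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body

    # OpenAI-style body where the client sent a numeric user id.
    assert get_end_user_id_from_request_body({"user": 123}) == "123"
    # Missing / null user ids still fall through to None.
    assert get_end_user_id_from_request_body({"user": None}) is None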


@@ -1051,7 +1051,6 @@ async def update_database(  # noqa: PLR0915
response_obj=completion_response,
start_time=start_time,
end_time=end_time,
end_user_id=end_user_id,
)
payload["spend"] = response_cost


@@ -10,6 +10,7 @@ from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
from litellm.proxy.utils import PrismaClient, hash_token
from litellm.types.utils import StandardLoggingPayload
from litellm.utils import get_end_user_id_for_cost_tracking
def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
@@ -29,16 +30,37 @@ def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
return False
def get_logging_payload(
kwargs, response_obj, start_time, end_time, end_user_id: Optional[str]
) -> SpendLogsPayload:
from litellm.proxy.proxy_server import general_settings, master_key
def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
if metadata is None:
return SpendLogsMetadata(
user_api_key=None,
user_api_key_alias=None,
user_api_key_team_id=None,
user_api_key_user_id=None,
user_api_key_team_alias=None,
spend_logs_metadata=None,
requester_ip_address=None,
additional_usage_values=None,
)
verbose_proxy_logger.debug(
f"SpendTable: get_logging_payload - kwargs: {kwargs}\n\n"
"getting payload for SpendLogs, available keys in metadata: "
+ str(list(metadata.keys()))
)
# Filter the metadata dictionary to include only the specified keys
clean_metadata = SpendLogsMetadata(
**{ # type: ignore
key: metadata[key]
for key in SpendLogsMetadata.__annotations__.keys()
if key in metadata
}
)
return clean_metadata
def get_logging_payload(kwargs, response_obj, start_time, end_time) -> SpendLogsPayload:
from litellm.proxy.proxy_server import general_settings, master_key
if kwargs is None:
kwargs = {}
if response_obj is None or (
@@ -57,54 +79,34 @@ get_logging_payload
if isinstance(usage, litellm.Usage):
usage = dict(usage)
id = cast(dict, response_obj).get("id") or kwargs.get("litellm_call_id")
api_key = metadata.get("user_api_key", "")
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
standard_logging_payload = cast(
Optional[StandardLoggingPayload], kwargs.get("standard_logging_object", None)
)
if api_key is not None and isinstance(api_key, str):
if api_key.startswith("sk-"):
# hash the api_key
api_key = hash_token(api_key)
if (
_is_master_key(api_key=api_key, _master_key=master_key)
and general_settings.get("disable_adding_master_key_hash_to_db") is True
):
api_key = "litellm_proxy_master_key" # use a known alias, if the user disabled storing master key in db
_model_id = metadata.get("model_info", {}).get("id", "")
_model_group = metadata.get("model_group", "")
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
if standard_logging_payload is not None:
api_key = standard_logging_payload["metadata"].get("user_api_key_hash") or ""
end_user_id = end_user_id or standard_logging_payload["metadata"].get(
"user_api_key_end_user_id"
)
else:
api_key = ""
request_tags = (
json.dumps(metadata.get("tags", []))
if isinstance(metadata.get("tags", []), list)
else "[]"
)
if (
_is_master_key(api_key=api_key, _master_key=master_key)
and general_settings.get("disable_adding_master_key_hash_to_db") is True
):
api_key = "litellm_proxy_master_key" # use a known alias, if the user disabled storing master key in db
_model_id = metadata.get("model_info", {}).get("id", "")
_model_group = metadata.get("model_group", "")
# clean up litellm metadata
clean_metadata = SpendLogsMetadata(
user_api_key=None,
user_api_key_alias=None,
user_api_key_team_id=None,
user_api_key_user_id=None,
user_api_key_team_alias=None,
spend_logs_metadata=None,
requester_ip_address=None,
additional_usage_values=None,
)
if isinstance(metadata, dict):
verbose_proxy_logger.debug(
"getting payload for SpendLogs, available keys in metadata: "
+ str(list(metadata.keys()))
)
# Filter the metadata dictionary to include only the specified keys
clean_metadata = SpendLogsMetadata(
**{ # type: ignore
key: metadata[key]
for key in SpendLogsMetadata.__annotations__.keys()
if key in metadata
}
)
clean_metadata = _get_spend_logs_metadata(metadata)
special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
additional_usage_values = {}


@@ -2432,6 +2432,126 @@ async def reset_budget(prisma_client: PrismaClient):
)
class ProxyUpdateSpend:
@staticmethod
async def update_end_user_spend(
n_retry_times: int, prisma_client: PrismaClient, proxy_logging_obj: ProxyLogging
):
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
) as transaction:
async with transaction.batch_() as batcher:
for (
end_user_id,
response_cost,
) in prisma_client.end_user_list_transactons.items():
if litellm.max_end_user_budget is not None:
pass
batcher.litellm_endusertable.upsert(
where={"user_id": end_user_id},
data={
"create": {
"user_id": end_user_id,
"spend": response_cost,
"blocked": False,
},
"update": {"spend": {"increment": response_cost}},
},
)
break
except DB_CONNECTION_ERROR_TYPES as e:
if i >= n_retry_times: # If we've reached the maximum number of retries
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
finally:
prisma_client.end_user_list_transactons = (
{}
) # reset the end user list transactions - prevent bad data from causing issues
@staticmethod
async def update_spend_logs(
n_retry_times: int,
prisma_client: PrismaClient,
db_writer_client: Optional[HTTPHandler],
proxy_logging_obj: ProxyLogging,
):
BATCH_SIZE = 100 # Preferred size of each batch to write to the database
MAX_LOGS_PER_INTERVAL = (
1000 # Maximum number of logs to flush in a single interval
)
# Get initial logs to process
logs_to_process = prisma_client.spend_log_transactions[:MAX_LOGS_PER_INTERVAL]
start_time = time.time()
try:
for i in range(n_retry_times + 1):
try:
base_url = os.getenv("SPEND_LOGS_URL", None)
if (
len(logs_to_process) > 0
and base_url is not None
and db_writer_client is not None
):
if not base_url.endswith("/"):
base_url += "/"
verbose_proxy_logger.debug("base_url: {}".format(base_url))
response = await db_writer_client.post(
url=base_url + "spend/update",
data=json.dumps(logs_to_process),
headers={"Content-Type": "application/json"},
)
if response.status_code == 200:
prisma_client.spend_log_transactions = (
prisma_client.spend_log_transactions[
len(logs_to_process) :
]
)
else:
for j in range(0, len(logs_to_process), BATCH_SIZE):
batch = logs_to_process[j : j + BATCH_SIZE]
batch_with_dates = [
prisma_client.jsonify_object({**entry})
for entry in batch
]
await prisma_client.db.litellm_spendlogs.create_many(
data=batch_with_dates, skip_duplicates=True
)
verbose_proxy_logger.debug(
f"Flushed {len(batch)} logs to the DB."
)
prisma_client.spend_log_transactions = (
prisma_client.spend_log_transactions[len(logs_to_process) :]
)
verbose_proxy_logger.debug(
f"{len(logs_to_process)} logs processed. Remaining in queue: {len(prisma_client.spend_log_transactions)}"
)
break
except DB_CONNECTION_ERROR_TYPES:
if i is None:
i = 0
if i >= n_retry_times:
raise
await asyncio.sleep(2**i)
except Exception as e:
prisma_client.spend_log_transactions = prisma_client.spend_log_transactions[
len(logs_to_process) :
]
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
async def update_spend( # noqa: PLR0915
prisma_client: PrismaClient,
db_writer_client: Optional[HTTPHandler],
@@ -2490,47 +2610,11 @@ async def update_spend(  # noqa: PLR0915
)
)
if len(prisma_client.end_user_list_transactons.keys()) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
async with prisma_client.db.tx(
timeout=timedelta(seconds=60)
) as transaction:
async with transaction.batch_() as batcher:
for (
end_user_id,
response_cost,
) in prisma_client.end_user_list_transactons.items():
if litellm.max_end_user_budget is not None:
pass
batcher.litellm_endusertable.upsert(
where={"user_id": end_user_id},
data={
"create": {
"user_id": end_user_id,
"spend": response_cost,
"blocked": False,
},
"update": {"spend": {"increment": response_cost}},
},
)
prisma_client.end_user_list_transactons = (
{}
) # Clear the remaining transactions after processing all batches in the loop.
break
except DB_CONNECTION_ERROR_TYPES as e:
if i >= n_retry_times: # If we've reached the maximum number of retries
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
except Exception as e:
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
await ProxyUpdateSpend.update_end_user_spend(
n_retry_times=n_retry_times,
prisma_client=prisma_client,
proxy_logging_obj=proxy_logging_obj,
)
### UPDATE KEY TABLE ###
verbose_proxy_logger.debug(
"KEY Spend transactions: {}".format(
@@ -2687,80 +2771,13 @@ async def update_spend(  # noqa: PLR0915
"Spend Logs transactions: {}".format(len(prisma_client.spend_log_transactions))
)
BATCH_SIZE = 100 # Preferred size of each batch to write to the database
MAX_LOGS_PER_INTERVAL = 1000 # Maximum number of logs to flush in a single interval
if len(prisma_client.spend_log_transactions) > 0:
for i in range(n_retry_times + 1):
start_time = time.time()
try:
base_url = os.getenv("SPEND_LOGS_URL", None)
## WRITE TO SEPARATE SERVER ##
if (
len(prisma_client.spend_log_transactions) > 0
and base_url is not None
and db_writer_client is not None
):
if not base_url.endswith("/"):
base_url += "/"
verbose_proxy_logger.debug("base_url: {}".format(base_url))
response = await db_writer_client.post(
url=base_url + "spend/update",
data=json.dumps(prisma_client.spend_log_transactions), # type: ignore
headers={"Content-Type": "application/json"},
)
if response.status_code == 200:
prisma_client.spend_log_transactions = []
else: ## (default) WRITE TO DB ##
logs_to_process = prisma_client.spend_log_transactions[
:MAX_LOGS_PER_INTERVAL
]
for j in range(0, len(logs_to_process), BATCH_SIZE):
# Create sublist for current batch, ensuring it doesn't exceed the BATCH_SIZE
batch = logs_to_process[j : j + BATCH_SIZE]
# Convert datetime strings to Date objects
batch_with_dates = [
prisma_client.jsonify_object(
{
**entry,
}
)
for entry in batch
]
await prisma_client.db.litellm_spendlogs.create_many(
data=batch_with_dates, skip_duplicates=True # type: ignore
)
verbose_proxy_logger.debug(
f"Flushed {len(batch)} logs to the DB."
)
# Remove the processed logs from spend_logs
prisma_client.spend_log_transactions = (
prisma_client.spend_log_transactions[len(logs_to_process) :]
)
verbose_proxy_logger.debug(
f"{len(logs_to_process)} logs processed. Remaining in queue: {len(prisma_client.spend_log_transactions)}"
)
break
except DB_CONNECTION_ERROR_TYPES as e:
if i is None:
i = 0
if (
i >= n_retry_times
): # If we've reached the maximum number of retries raise the exception
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # type: ignore
except Exception as e:
_raise_failed_update_spend_exception(
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
)
await ProxyUpdateSpend.update_spend_logs(
n_retry_times=n_retry_times,
prisma_client=prisma_client,
proxy_logging_obj=proxy_logging_obj,
db_writer_client=db_writer_client,
)
def _raise_failed_update_spend_exception(
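
Both methods on `ProxyUpdateSpend` now guarantee the in-memory buffer is cleared whether the write succeeds or fails. For spend logs, the contract reduces to roughly the following (a simplified sketch, not the actual implementation; `flush` is a placeholder for the DB write):

    async def _flush_with_cleanup(prisma_client, logs_to_process, flush):
        # `flush` stands in for the real write (create_many / POST to SPEND_LOGS_URL).
        try:
            await flush(logs_to_process)
        except Exception:
            # Failure: still drop the slice, so a malformed log is not
            # retried forever and cannot block later spend-log writes.
            prisma_client.spend_log_transactions = (
                prisma_client.spend_log_transactions[len(logs_to_process):]
            )
            raise
        # Success: drop the processed slice from the in-memory queue.
        prisma_client.spend_log_transactions = (
            prisma_client.spend_log_transactions[len(logs_to_process):]
        )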


@@ -2132,6 +2132,7 @@ def get_litellm_params(
prompt_id: Optional[str] = None,
prompt_variables: Optional[dict] = None,
async_call: Optional[bool] = None,
ssl_verify: Optional[bool] = None,
**kwargs,
) -> dict:
litellm_params = {
@@ -2170,6 +2171,7 @@ def get_litellm_params(
"prompt_id": prompt_id,
"prompt_variables": prompt_variables,
"async_call": async_call,
"ssl_verify": ssl_verify,
}
return litellm_params


@@ -68,3 +68,12 @@ def test_configurable_clientside_parameters(
)
print(resp)
assert resp == should_return_true
def test_get_end_user_id_from_request_body_always_returns_str():
from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body
request_body = {"user": 123}
end_user_id = get_end_user_id_from_request_body(request_body)
assert end_user_id == "123"
assert isinstance(end_user_id, str)


@@ -174,3 +174,67 @@ def test_ollama_chat_function_calling():
print(json.loads(tool_calls[0].function.arguments))
print(response)
def test_ollama_ssl_verify():
from litellm.llms.custom_httpx.http_handler import HTTPHandler
import ssl
import httpx
try:
response = litellm.completion(
model="ollama/llama3.1",
messages=[
{
"role": "user",
"content": "What's the weather like in San Francisco?",
}
],
ssl_verify=False,
)
except Exception as e:
print(e)
client: HTTPHandler = litellm.in_memory_llm_clients_cache.get_cache(
"httpx_clientssl_verify_False"
)
test_client = httpx.Client(verify=False)
print(client)
assert (
client.client._transport._pool._ssl_context.verify_mode
== test_client._transport._pool._ssl_context.verify_mode
)
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.asyncio
async def test_async_ollama_ssl_verify(stream):
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
import httpx
try:
response = await litellm.acompletion(
model="ollama/llama3.1",
messages=[
{
"role": "user",
"content": "What's the weather like in San Francisco?",
}
],
ssl_verify=False,
stream=stream,
)
except Exception as e:
print(e)
client: AsyncHTTPHandler = litellm.in_memory_llm_clients_cache.get_cache(
"async_httpx_clientssl_verify_Falseollama"
)
test_client = httpx.AsyncClient(verify=False)
print(client)
assert (
client.client._transport._pool._ssl_context.verify_mode
== test_client._transport._pool._ssl_context.verify_mode
)


@@ -170,6 +170,12 @@ def test_spend_logs_payload(model_id: Optional[str]):
"end_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 954146),
"cache_hit": None,
"response_cost": 2.4999999999999998e-05,
"standard_logging_object": {
"request_tags": ["model-anthropic-claude-v2.1", "app-ishaan-prod"],
"metadata": {
"user_api_key_end_user_id": "test-user",
},
},
},
"response_obj": litellm.ModelResponse(
id=model_id,
@@ -192,7 +198,6 @@
),
"start_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 308604),
"end_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 954146),
"end_user_id": None,
}
payload: SpendLogsPayload = get_logging_payload(**input_args)
@@ -229,6 +234,7 @@ def test_spend_logs_payload_whisper():
"metadata": {
"user_api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
"user_api_key_alias": None,
"user_api_key_end_user_id": "test-user",
"user_api_end_user_max_budget": None,
"litellm_api_version": "1.40.19",
"global_max_parallel_requests": None,
@@ -293,7 +299,6 @@ def test_spend_logs_payload_whisper():
response_obj=response,
start_time=datetime.datetime.now(),
end_time=datetime.datetime.now(),
end_user_id="test-user",
)
print("payload: ", payload)
@@ -335,13 +340,16 @@ def test_spend_logs_payload_with_prompts_enabled(monkeypatch):
),
"start_time": datetime.datetime.now(),
"end_time": datetime.datetime.now(),
"end_user_id": "user123",
}
# Create a standard logging payload
standard_logging_payload = {
"messages": [{"role": "user", "content": "Hello!"}],
"response": {"role": "assistant", "content": "Hi there!"},
"metadata": {
"user_api_key_end_user_id": "test-user",
},
"request_tags": ["model-anthropic-claude-v2.1", "app-ishaan-prod"],
}
input_args["kwargs"]["standard_logging_object"] = standard_logging_payload


@@ -1497,6 +1497,62 @@ def test_custom_openapi(mock_get_openapi_schema):
assert openapi_schema is not None
import pytest
from unittest.mock import MagicMock, AsyncMock
import asyncio
from datetime import timedelta
from litellm.proxy.utils import ProxyUpdateSpend
@pytest.mark.asyncio
async def test_end_user_transactions_reset():
# Setup
mock_client = MagicMock()
mock_client.end_user_list_transactons = {"1": 10.0} # Bad log
mock_client.db.tx = AsyncMock(side_effect=Exception("DB Error"))
# Call function - should raise error
with pytest.raises(Exception):
await ProxyUpdateSpend.update_end_user_spend(
n_retry_times=0, prisma_client=mock_client, proxy_logging_obj=MagicMock()
)
# Verify cleanup happened
assert (
mock_client.end_user_list_transactons == {}
), "Transactions list should be empty after error"
@pytest.mark.asyncio
async def test_spend_logs_cleanup_after_error():
# Setup test data
mock_client = MagicMock()
mock_client.spend_log_transactions = [
{"id": 1, "amount": 10.0},
{"id": 2, "amount": 20.0},
{"id": 3, "amount": 30.0},
]
# Make the DB operation fail
mock_client.db.litellm_spendlogs.create_many = AsyncMock(
side_effect=Exception("DB Error")
)
original_logs = mock_client.spend_log_transactions.copy()
# Call function - should raise error
with pytest.raises(Exception):
await ProxyUpdateSpend.update_spend_logs(
n_retry_times=0,
prisma_client=mock_client,
db_writer_client=None, # Test DB write path
proxy_logging_obj=MagicMock(),
)
# Verify the first batch was removed from spend_log_transactions
assert (
mock_client.spend_log_transactions == original_logs[100:]
), "Should remove processed logs even after error"
def test_provider_specific_header():
from litellm.proxy.litellm_pre_call_utils import (
add_provider_specific_headers_to_request,


@@ -862,3 +862,18 @@ async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypa
await _jwt_auth_user_api_key_auth_builder(**args)
else:
await _jwt_auth_user_api_key_auth_builder(**args)
def test_user_api_key_auth_end_user_str():
from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth
user_api_key_args = {
"api_key": "sk-1234",
"parent_otel_span": None,
"user_role": LitellmUserRoles.PROXY_ADMIN,
"end_user_id": "1",
"user_id": "default_user_id",
}
user_api_key_auth = UserAPIKeyAuth(**user_api_key_args)
assert user_api_key_auth.end_user_id == "1"