Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 02:34:29 +00:00
Ollama ssl verify = False + Spend Logs reliability fixes (#7931)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 13s
* fix(http_handler.py): support passing ssl_verify dynamically and use the correct httpx client based on the passed ssl_verify param
Fixes https://github.com/BerriAI/litellm/issues/6499
* feat(llm_http_handler.py): support passing `ssl_verify=False` dynamically in call args
Closes https://github.com/BerriAI/litellm/issues/6499
* fix(proxy/utils.py): prevent bad logs from breaking all cost tracking + reset the in-memory list regardless of success/failure

prevents malformed logs from breaking all spend tracking, since failed writes are otherwise retried indefinitely
* test(test_proxy_utils.py): add test to ensure bad log is dropped
* test(test_proxy_utils.py): ensure in-memory spend logs reset after bad log error
* test(test_user_api_key_auth.py): add unit test to ensure end user id as str works
* fix(auth_utils.py): ensure extracted end user id is always a str
prevents db cost tracking errors
* test(test_auth_utils.py): ensure get end user id from request body always returns a string
* test: update tests
* test: skip bedrock test - behaviour now supported
* test: fix testing
* refactor(spend_tracking_utils.py): reduce size of get_logging_payload
* test: fix test
* bump: version 1.59.4 → 1.59.5
* Revert "bump: version 1.59.4 → 1.59.5"
This reverts commit 1182b46b2e
.
* fix(utils.py): fix spend logs retry logic
* fix(spend_tracking_utils.py): fix get tags
* fix(spend_tracking_utils.py): fix end user id spend tracking on pass-through endpoints
This commit is contained in:
parent 851b0c4c4d
commit 1e011b66d3

17 changed files with 406 additions and 187 deletions
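Taken together, the ssl_verify changes below let a single request opt out of certificate verification. A minimal usage sketch (the model name and api_base are illustrative, not taken from this commit):

import litellm

# ssl_verify=False travels through litellm_params down to the underlying
# httpx client, so a self-signed endpoint no longer fails the TLS handshake
response = litellm.completion(
    model="ollama/llama3.1",
    messages=[{"role": "user", "content": "Hello"}],
    api_base="https://localhost:11434",  # assumed self-signed Ollama endpoint
    ssl_verify=False,
)
print(response.choices[0].message.content)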
@@ -93,11 +93,15 @@ class AsyncHTTPHandler:
         event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]] = None,
         concurrent_limit=1000,
         client_alias: Optional[str] = None,  # name for client in logs
+        ssl_verify: Optional[Union[bool, str]] = None,
     ):
         self.timeout = timeout
         self.event_hooks = event_hooks
         self.client = self.create_client(
-            timeout=timeout, concurrent_limit=concurrent_limit, event_hooks=event_hooks
+            timeout=timeout,
+            concurrent_limit=concurrent_limit,
+            event_hooks=event_hooks,
+            ssl_verify=ssl_verify,
         )
         self.client_alias = client_alias

@@ -106,11 +110,13 @@ class AsyncHTTPHandler:
         timeout: Optional[Union[float, httpx.Timeout]],
         concurrent_limit: int,
         event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]],
+        ssl_verify: Optional[Union[bool, str]] = None,
     ) -> httpx.AsyncClient:

         # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
         # /path/to/certificate.pem
-        ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
+        if ssl_verify is None:
+            ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
         # An SSL certificate used by the requested host to authenticate the client.
         # /path/to/client.pem
         cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)
@@ -440,13 +446,17 @@ class HTTPHandler:
         timeout: Optional[Union[float, httpx.Timeout]] = None,
         concurrent_limit=1000,
         client: Optional[httpx.Client] = None,
+        ssl_verify: Optional[Union[bool, str]] = None,
     ):
         if timeout is None:
             timeout = _DEFAULT_TIMEOUT

         # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
         # /path/to/certificate.pem
-        ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
+
+        if ssl_verify is None:
+            ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
+
         # An SSL certificate used by the requested host to authenticate the client.
         # /path/to/client.pem
         cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)
@@ -506,7 +516,15 @@ class HTTPHandler:
         try:
             if timeout is not None:
                 req = self.client.build_request(
-                    "POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout, files=files, content=content  # type: ignore
+                    "POST",
+                    url,
+                    data=data,  # type: ignore
+                    json=json,
+                    params=params,
+                    headers=headers,
+                    timeout=timeout,
+                    files=files,
+                    content=content,  # type: ignore
                 )
             else:
                 req = self.client.build_request(
@@ -660,6 +678,7 @@ def get_async_httpx_client(
         _new_client = AsyncHTTPHandler(
             timeout=httpx.Timeout(timeout=600.0, connect=5.0)
         )
+
     litellm.in_memory_llm_clients_cache.set_cache(
         key=_cache_key_name,
         value=_new_client,
@@ -684,6 +703,7 @@ def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler:
         pass

     _cache_key_name = "httpx_client" + _params_key_name
+
     _cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key_name)
     if _cached_client:
         return _cached_client
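The _params_key_name suffix above is what gives each distinct ssl_verify value its own cached client. A sketch of how the key appears to compose, inferred from the cache-key literals asserted in the tests further down ("httpx_clientssl_verify_False"), not from the source of _params_key_name itself:

def params_key(params: dict) -> str:
    # assumed composition: "{key}_{value}" per entry, concatenated onto the base name
    return "".join(f"{k}_{v}" for k, v in params.items())

assert "httpx_client" + params_key({"ssl_verify": False}) == "httpx_clientssl_verify_False"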
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union

 import httpx

@@ -36,13 +36,13 @@ class HTTPHandler:
     async def post(
         self,
         url: str,
-        data: Optional[dict] = None,
+        data: Optional[Union[dict, str]] = None,
         params: Optional[dict] = None,
         headers: Optional[dict] = None,
     ):
         try:
             response = await self.client.post(
-                url, data=data, params=params, headers=headers
+                url, data=data, params=params, headers=headers  # type: ignore
             )
             return response
         except Exception as e:
@@ -158,7 +158,8 @@ class BaseLLMHTTPHandler:
     ):
         if client is None:
             async_httpx_client = get_async_httpx_client(
-                llm_provider=litellm.LlmProviders(custom_llm_provider)
+                llm_provider=litellm.LlmProviders(custom_llm_provider),
+                params={"ssl_verify": litellm_params.get("ssl_verify", None)},
             )
         else:
             async_httpx_client = client
@@ -318,7 +319,9 @@ class BaseLLMHTTPHandler:
         )

         if client is None or not isinstance(client, HTTPHandler):
-            sync_httpx_client = _get_httpx_client()
+            sync_httpx_client = _get_httpx_client(
+                params={"ssl_verify": litellm_params.get("ssl_verify", None)}
+            )
         else:
             sync_httpx_client = client

@@ -359,7 +362,11 @@ class BaseLLMHTTPHandler:
         client: Optional[HTTPHandler] = None,
     ) -> Tuple[Any, dict]:
         if client is None or not isinstance(client, HTTPHandler):
-            sync_httpx_client = _get_httpx_client()
+            sync_httpx_client = _get_httpx_client(
+                {
+                    "ssl_verify": litellm_params.get("ssl_verify", None),
+                }
+            )
         else:
             sync_httpx_client = client
         stream = True
@@ -411,7 +418,7 @@ class BaseLLMHTTPHandler:
         fake_stream: bool = False,
         client: Optional[AsyncHTTPHandler] = None,
     ):
-        completion_stream, _response_headers = await self.make_async_call(
+        completion_stream, _response_headers = await self.make_async_call_stream_helper(
             custom_llm_provider=custom_llm_provider,
             provider_config=provider_config,
             api_base=api_base,
@@ -432,7 +439,7 @@ class BaseLLMHTTPHandler:
         )
         return streamwrapper

-    async def make_async_call(
+    async def make_async_call_stream_helper(
         self,
         custom_llm_provider: str,
         provider_config: BaseConfig,
@@ -446,9 +453,15 @@ class BaseLLMHTTPHandler:
         fake_stream: bool = False,
         client: Optional[AsyncHTTPHandler] = None,
     ) -> Tuple[Any, httpx.Headers]:
+        """
+        Helper function for making an async call with stream.
+
+        Handles fake stream as well.
+        """
         if client is None:
             async_httpx_client = get_async_httpx_client(
-                llm_provider=litellm.LlmProviders(custom_llm_provider)
+                llm_provider=litellm.LlmProviders(custom_llm_provider),
+                params={"ssl_verify": litellm_params.get("ssl_verify", None)},
             )
         else:
             async_httpx_client = client
@@ -855,6 +855,8 @@ def completion(  # type: ignore # noqa: PLR0915
     cooldown_time = kwargs.get("cooldown_time", None)
     context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
     organization = kwargs.get("organization", None)
+    ### VERIFY SSL ###
+    ssl_verify = kwargs.get("ssl_verify", None)
     ### CUSTOM MODEL COST ###
     input_cost_per_token = kwargs.get("input_cost_per_token", None)
     output_cost_per_token = kwargs.get("output_cost_per_token", None)
@@ -1102,6 +1104,7 @@ def completion(  # type: ignore # noqa: PLR0915
         drop_params=kwargs.get("drop_params"),
         prompt_id=prompt_id,
         prompt_variables=prompt_variables,
+        ssl_verify=ssl_verify,
     )
     logging.update_environment_variables(
         model=model,
litellm/proxy/_new_new_secret_config.yaml (new file, +14)
@@ -0,0 +1,14 @@
+model_list:
+  - model_name: bedrock-claude
+    litellm_params:
+      model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
+      aws_region_name: us-east-1
+      aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
+      aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
+
+litellm_settings:
+  callbacks: ["datadog"] # logs llm success + failure logs on datadog
+  service_callback: ["datadog"] # logs redis, postgres failures on datadog
+
+general_settings:
+  store_prompts_in_spend_logs: true
@@ -11,7 +11,3 @@ model_list:
       api_base: http://0.0.0.0:8090
       timeout: 2
       num_retries: 0
-
-
-litellm_settings:
-  success_callback: ["langfuse"]
@@ -1483,7 +1483,8 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
             # Check if the value is None and set the corresponding attribute
             if getattr(self, attr_name, None) is None:
                 kwargs[attr_name] = value
-
+            if key == "end_user_id" and value is not None and isinstance(value, int):
+                kwargs[key] = str(value)
         # Initialize the superclass
         super().__init__(**kwargs)

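A sketch of what the new guard buys (hypothetical construction, assuming the class lives in litellm.proxy._types and accepts a token kwarg; in practice the view is populated from a DB join that can yield an int):

from litellm.proxy._types import LiteLLM_VerificationTokenView

view = LiteLLM_VerificationTokenView(token="hashed-token", end_user_id=1)  # int from the DB layer
assert view.end_user_id == "1"  # coerced to str, so spend tracking inserts don't fail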
@@ -498,13 +498,13 @@ def _has_user_setup_sso():

 def get_end_user_id_from_request_body(request_body: dict) -> Optional[str]:
     # openai - check 'user'
-    if "user" in request_body:
-        return request_body["user"]
+    if "user" in request_body and request_body["user"] is not None:
+        return str(request_body["user"])
     # anthropic - check 'litellm_metadata'
     end_user_id = request_body.get("litellm_metadata", {}).get("user", None)
     if end_user_id:
-        return end_user_id
+        return str(end_user_id)
     metadata = request_body.get("metadata")
-    if metadata and "user_id" in metadata:
-        return metadata["user_id"]
+    if metadata and "user_id" in metadata and metadata["user_id"] is not None:
+        return str(metadata["user_id"])
     return None
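Every extraction path in the hardened helper now returns a string or None. Expected behaviour per branch (the first case is asserted by the new unit test below; the others follow from the code above):

from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body

assert get_end_user_id_from_request_body({"user": 123}) == "123"
assert get_end_user_id_from_request_body({"litellm_metadata": {"user": 42}}) == "42"
assert get_end_user_id_from_request_body({"metadata": {"user_id": None}}) is None
assert get_end_user_id_from_request_body({}) is None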
@@ -1051,7 +1051,6 @@ async def update_database(  # noqa: PLR0915
                 response_obj=completion_response,
                 start_time=start_time,
                 end_time=end_time,
-                end_user_id=end_user_id,
             )

             payload["spend"] = response_cost
@@ -10,6 +10,7 @@ from litellm._logging import verbose_proxy_logger
 from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
 from litellm.proxy.utils import PrismaClient, hash_token
 from litellm.types.utils import StandardLoggingPayload
+from litellm.utils import get_end_user_id_for_cost_tracking


 def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
@@ -29,16 +30,37 @@ def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
     return False


-def get_logging_payload(
-    kwargs, response_obj, start_time, end_time, end_user_id: Optional[str]
-) -> SpendLogsPayload:
-
-    from litellm.proxy.proxy_server import general_settings, master_key
-
-    verbose_proxy_logger.debug(
-        f"SpendTable: get_logging_payload - kwargs: {kwargs}\n\n"
-    )
+def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
+    if metadata is None:
+        return SpendLogsMetadata(
+            user_api_key=None,
+            user_api_key_alias=None,
+            user_api_key_team_id=None,
+            user_api_key_user_id=None,
+            user_api_key_team_alias=None,
+            spend_logs_metadata=None,
+            requester_ip_address=None,
+            additional_usage_values=None,
+        )
+    verbose_proxy_logger.debug(
+        "getting payload for SpendLogs, available keys in metadata: "
+        + str(list(metadata.keys()))
+    )
+
+    # Filter the metadata dictionary to include only the specified keys
+    clean_metadata = SpendLogsMetadata(
+        **{  # type: ignore
+            key: metadata[key]
+            for key in SpendLogsMetadata.__annotations__.keys()
+            if key in metadata
+        }
+    )
+    return clean_metadata
+
+
+def get_logging_payload(kwargs, response_obj, start_time, end_time) -> SpendLogsPayload:
+    from litellm.proxy.proxy_server import general_settings, master_key

     if kwargs is None:
         kwargs = {}
     if response_obj is None or (
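The extracted _get_spend_logs_metadata helper centralizes the old inline logic: an all-None SpendLogsMetadata when metadata is missing, otherwise only the keys declared on the TypedDict. Illustratively (assuming "model_group" is not a declared SpendLogsMetadata field):

clean = _get_spend_logs_metadata({"user_api_key": "hashed", "model_group": "gpt-4"})
assert clean.get("user_api_key") == "hashed"   # declared field survives
assert "model_group" not in clean              # undeclared keys are filtered out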
@@ -57,54 +79,34 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time) -> SpendLogsPayload:
     if isinstance(usage, litellm.Usage):
         usage = dict(usage)
     id = cast(dict, response_obj).get("id") or kwargs.get("litellm_call_id")
-    api_key = metadata.get("user_api_key", "")
-    standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
-        "standard_logging_object", None
+    standard_logging_payload = cast(
+        Optional[StandardLoggingPayload], kwargs.get("standard_logging_object", None)
     )
-    if api_key is not None and isinstance(api_key, str):
-        if api_key.startswith("sk-"):
-            # hash the api_key
-            api_key = hash_token(api_key)
-        if (
-            _is_master_key(api_key=api_key, _master_key=master_key)
-            and general_settings.get("disable_adding_master_key_hash_to_db") is True
-        ):
-            api_key = "litellm_proxy_master_key"  # use a known alias, if the user disabled storing master key in db
-
-    _model_id = metadata.get("model_info", {}).get("id", "")
-    _model_group = metadata.get("model_group", "")
-
+
+    end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
+    if standard_logging_payload is not None:
+        api_key = standard_logging_payload["metadata"].get("user_api_key_hash") or ""
+        end_user_id = end_user_id or standard_logging_payload["metadata"].get(
+            "user_api_key_end_user_id"
+        )
+    else:
+        api_key = ""
     request_tags = (
         json.dumps(metadata.get("tags", []))
         if isinstance(metadata.get("tags", []), list)
         else "[]"
     )

+    if (
+        _is_master_key(api_key=api_key, _master_key=master_key)
+        and general_settings.get("disable_adding_master_key_hash_to_db") is True
+    ):
+        api_key = "litellm_proxy_master_key"  # use a known alias, if the user disabled storing master key in db
+
+    _model_id = metadata.get("model_info", {}).get("id", "")
+    _model_group = metadata.get("model_group", "")
+
-    # clean up litellm metadata
-    clean_metadata = SpendLogsMetadata(
-        user_api_key=None,
-        user_api_key_alias=None,
-        user_api_key_team_id=None,
-        user_api_key_user_id=None,
-        user_api_key_team_alias=None,
-        spend_logs_metadata=None,
-        requester_ip_address=None,
-        additional_usage_values=None,
-    )
-    if isinstance(metadata, dict):
-        verbose_proxy_logger.debug(
-            "getting payload for SpendLogs, available keys in metadata: "
-            + str(list(metadata.keys()))
-        )
-
-        # Filter the metadata dictionary to include only the specified keys
-        clean_metadata = SpendLogsMetadata(
-            **{  # type: ignore
-                key: metadata[key]
-                for key in SpendLogsMetadata.__annotations__.keys()
-                if key in metadata
-            }
-        )
+    clean_metadata = _get_spend_logs_metadata(metadata)

     special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
     additional_usage_values = {}
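The pass-through fix reorders where end_user_id comes from: the per-request litellm_params win, with the proxy auth metadata as fallback. Condensed from the hunk above:

end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
if standard_logging_payload is not None:
    # fall back to the end user id the proxy's auth layer recorded
    end_user_id = end_user_id or standard_logging_payload["metadata"].get(
        "user_api_key_end_user_id"
    )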
@@ -2432,6 +2432,126 @@ async def reset_budget(prisma_client: PrismaClient):
         )


+class ProxyUpdateSpend:
+    @staticmethod
+    async def update_end_user_spend(
+        n_retry_times: int, prisma_client: PrismaClient, proxy_logging_obj: ProxyLogging
+    ):
+        for i in range(n_retry_times + 1):
+            start_time = time.time()
+            try:
+                async with prisma_client.db.tx(
+                    timeout=timedelta(seconds=60)
+                ) as transaction:
+                    async with transaction.batch_() as batcher:
+                        for (
+                            end_user_id,
+                            response_cost,
+                        ) in prisma_client.end_user_list_transactons.items():
+                            if litellm.max_end_user_budget is not None:
+                                pass
+                            batcher.litellm_endusertable.upsert(
+                                where={"user_id": end_user_id},
+                                data={
+                                    "create": {
+                                        "user_id": end_user_id,
+                                        "spend": response_cost,
+                                        "blocked": False,
+                                    },
+                                    "update": {"spend": {"increment": response_cost}},
+                                },
+                            )
+
+                break
+            except DB_CONNECTION_ERROR_TYPES as e:
+                if i >= n_retry_times:  # If we've reached the maximum number of retries
+                    _raise_failed_update_spend_exception(
+                        e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
+                    )
+                # Optionally, sleep for a bit before retrying
+                await asyncio.sleep(2**i)  # Exponential backoff
+            except Exception as e:
+                _raise_failed_update_spend_exception(
+                    e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
+                )
+            finally:
+                prisma_client.end_user_list_transactons = (
+                    {}
+                )  # reset the end user list transactions - prevent bad data from causing issues
+
+    @staticmethod
+    async def update_spend_logs(
+        n_retry_times: int,
+        prisma_client: PrismaClient,
+        db_writer_client: Optional[HTTPHandler],
+        proxy_logging_obj: ProxyLogging,
+    ):
+        BATCH_SIZE = 100  # Preferred size of each batch to write to the database
+        MAX_LOGS_PER_INTERVAL = (
+            1000  # Maximum number of logs to flush in a single interval
+        )
+        # Get initial logs to process
+        logs_to_process = prisma_client.spend_log_transactions[:MAX_LOGS_PER_INTERVAL]
+        start_time = time.time()
+        try:
+            for i in range(n_retry_times + 1):
+                try:
+                    base_url = os.getenv("SPEND_LOGS_URL", None)
+                    if (
+                        len(logs_to_process) > 0
+                        and base_url is not None
+                        and db_writer_client is not None
+                    ):
+                        if not base_url.endswith("/"):
+                            base_url += "/"
+                        verbose_proxy_logger.debug("base_url: {}".format(base_url))
+                        response = await db_writer_client.post(
+                            url=base_url + "spend/update",
+                            data=json.dumps(logs_to_process),
+                            headers={"Content-Type": "application/json"},
+                        )
+                        if response.status_code == 200:
+                            prisma_client.spend_log_transactions = (
+                                prisma_client.spend_log_transactions[
+                                    len(logs_to_process) :
+                                ]
+                            )
+                    else:
+                        for j in range(0, len(logs_to_process), BATCH_SIZE):
+                            batch = logs_to_process[j : j + BATCH_SIZE]
+                            batch_with_dates = [
+                                prisma_client.jsonify_object({**entry})
+                                for entry in batch
+                            ]
+                            await prisma_client.db.litellm_spendlogs.create_many(
+                                data=batch_with_dates, skip_duplicates=True
+                            )
+                            verbose_proxy_logger.debug(
+                                f"Flushed {len(batch)} logs to the DB."
+                            )
+
+                        prisma_client.spend_log_transactions = (
+                            prisma_client.spend_log_transactions[len(logs_to_process) :]
+                        )
+                        verbose_proxy_logger.debug(
+                            f"{len(logs_to_process)} logs processed. Remaining in queue: {len(prisma_client.spend_log_transactions)}"
+                        )
+                    break
+                except DB_CONNECTION_ERROR_TYPES:
+                    if i is None:
+                        i = 0
+                    if i >= n_retry_times:
+                        raise
+                    await asyncio.sleep(2**i)
+        except Exception as e:
+            prisma_client.spend_log_transactions = prisma_client.spend_log_transactions[
+                len(logs_to_process) :
+            ]
+            _raise_failed_update_spend_exception(
+                e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
+            )
+
+
 async def update_spend(  # noqa: PLR0915
     prisma_client: PrismaClient,
     db_writer_client: Optional[HTTPHandler],
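The batching arithmetic in update_spend_logs, as a standalone sketch (constants from the code above; the queue list is a stand-in for spend_log_transactions):

BATCH_SIZE = 100               # rows per create_many call
MAX_LOGS_PER_INTERVAL = 1000   # cap on logs drained per flush

queue = [{"id": i} for i in range(2500)]
logs_to_process = queue[:MAX_LOGS_PER_INTERVAL]   # 1000 taken this interval
batches = [
    logs_to_process[j : j + BATCH_SIZE]
    for j in range(0, len(logs_to_process), BATCH_SIZE)
]
assert len(batches) == 10
queue = queue[len(logs_to_process):]              # 1500 remain for the next flush
# on an unrecoverable error the same slice is dropped anyway, so one bad
# batch cannot be retried forever and wedge all subsequent spend tracking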
@@ -2490,47 +2610,11 @@ async def update_spend(  # noqa: PLR0915
             )
         )
-    if len(prisma_client.end_user_list_transactons.keys()) > 0:
-        for i in range(n_retry_times + 1):
-            start_time = time.time()
-            try:
-                async with prisma_client.db.tx(
-                    timeout=timedelta(seconds=60)
-                ) as transaction:
-                    async with transaction.batch_() as batcher:
-                        for (
-                            end_user_id,
-                            response_cost,
-                        ) in prisma_client.end_user_list_transactons.items():
-                            if litellm.max_end_user_budget is not None:
-                                pass
-                            batcher.litellm_endusertable.upsert(
-                                where={"user_id": end_user_id},
-                                data={
-                                    "create": {
-                                        "user_id": end_user_id,
-                                        "spend": response_cost,
-                                        "blocked": False,
-                                    },
-                                    "update": {"spend": {"increment": response_cost}},
-                                },
-                            )
-
-                prisma_client.end_user_list_transactons = (
-                    {}
-                )  # Clear the remaining transactions after processing all batches in the loop.
-                break
-            except DB_CONNECTION_ERROR_TYPES as e:
-                if i >= n_retry_times:  # If we've reached the maximum number of retries
-                    _raise_failed_update_spend_exception(
-                        e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
-                    )
-                # Optionally, sleep for a bit before retrying
-                await asyncio.sleep(2**i)  # Exponential backoff
-            except Exception as e:
-                _raise_failed_update_spend_exception(
-                    e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
-                )
-
+    await ProxyUpdateSpend.update_end_user_spend(
+        n_retry_times=n_retry_times,
+        prisma_client=prisma_client,
+        proxy_logging_obj=proxy_logging_obj,
+    )
     ### UPDATE KEY TABLE ###
     verbose_proxy_logger.debug(
         "KEY Spend transactions: {}".format(
@@ -2687,80 +2771,13 @@ async def update_spend(  # noqa: PLR0915
         "Spend Logs transactions: {}".format(len(prisma_client.spend_log_transactions))
     )

-    BATCH_SIZE = 100  # Preferred size of each batch to write to the database
-    MAX_LOGS_PER_INTERVAL = 1000  # Maximum number of logs to flush in a single interval
-
-    if len(prisma_client.spend_log_transactions) > 0:
-        for i in range(n_retry_times + 1):
-            start_time = time.time()
-            try:
-                base_url = os.getenv("SPEND_LOGS_URL", None)
-                ## WRITE TO SEPARATE SERVER ##
-                if (
-                    len(prisma_client.spend_log_transactions) > 0
-                    and base_url is not None
-                    and db_writer_client is not None
-                ):
-                    if not base_url.endswith("/"):
-                        base_url += "/"
-                    verbose_proxy_logger.debug("base_url: {}".format(base_url))
-                    response = await db_writer_client.post(
-                        url=base_url + "spend/update",
-                        data=json.dumps(prisma_client.spend_log_transactions),  # type: ignore
-                        headers={"Content-Type": "application/json"},
-                    )
-                    if response.status_code == 200:
-                        prisma_client.spend_log_transactions = []
-                else:  ## (default) WRITE TO DB ##
-                    logs_to_process = prisma_client.spend_log_transactions[
-                        :MAX_LOGS_PER_INTERVAL
-                    ]
-                    for j in range(0, len(logs_to_process), BATCH_SIZE):
-                        # Create sublist for current batch, ensuring it doesn't exceed the BATCH_SIZE
-                        batch = logs_to_process[j : j + BATCH_SIZE]
-
-                        # Convert datetime strings to Date objects
-                        batch_with_dates = [
-                            prisma_client.jsonify_object(
-                                {
-                                    **entry,
-                                }
-                            )
-                            for entry in batch
-                        ]
-
-                        await prisma_client.db.litellm_spendlogs.create_many(
-                            data=batch_with_dates, skip_duplicates=True  # type: ignore
-                        )
-
-                        verbose_proxy_logger.debug(
-                            f"Flushed {len(batch)} logs to the DB."
-                        )
-                    # Remove the processed logs from spend_logs
-                    prisma_client.spend_log_transactions = (
-                        prisma_client.spend_log_transactions[len(logs_to_process) :]
-                    )
-
-                    verbose_proxy_logger.debug(
-                        f"{len(logs_to_process)} logs processed. Remaining in queue: {len(prisma_client.spend_log_transactions)}"
-                    )
-                break
-            except DB_CONNECTION_ERROR_TYPES as e:
-                if i is None:
-                    i = 0
-                if (
-                    i >= n_retry_times
-                ):  # If we've reached the maximum number of retries raise the exception
-                    _raise_failed_update_spend_exception(
-                        e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
-                    )
-
-                # Optionally, sleep for a bit before retrying
-                await asyncio.sleep(2**i)  # type: ignore
-            except Exception as e:
-                _raise_failed_update_spend_exception(
-                    e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
-                )
+    await ProxyUpdateSpend.update_spend_logs(
+        n_retry_times=n_retry_times,
+        prisma_client=prisma_client,
+        proxy_logging_obj=proxy_logging_obj,
+        db_writer_client=db_writer_client,
+    )


 def _raise_failed_update_spend_exception(
@@ -2132,6 +2132,7 @@ def get_litellm_params(
     prompt_id: Optional[str] = None,
     prompt_variables: Optional[dict] = None,
    async_call: Optional[bool] = None,
+    ssl_verify: Optional[bool] = None,
     **kwargs,
 ) -> dict:
     litellm_params = {

@@ -2170,6 +2171,7 @@ def get_litellm_params(
         "prompt_id": prompt_id,
         "prompt_variables": prompt_variables,
         "async_call": async_call,
+        "ssl_verify": ssl_verify,
     }
     return litellm_params

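For orientation, the full path of the flag as assembled from the hunks in this commit:

# completion(ssl_verify=False)                             # main.py kwarg
#   -> get_litellm_params(ssl_verify=ssl_verify)           # utils.py (hunks above)
#   -> litellm_params["ssl_verify"]                        # carried per request
#   -> _get_httpx_client({"ssl_verify": ...}) or
#      get_async_httpx_client(params={"ssl_verify": ...})  # llm_http_handler.py
#   -> HTTPHandler / AsyncHTTPHandler(ssl_verify=...)      # http_handler.py
#   -> httpx.Client(verify=ssl_verify)                     # assumed final hand-off to httpx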
@@ -68,3 +68,12 @@ def test_configurable_clientside_parameters(
     )
     print(resp)
     assert resp == should_return_true
+
+
+def test_get_end_user_id_from_request_body_always_returns_str():
+    from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body
+
+    request_body = {"user": 123}
+    end_user_id = get_end_user_id_from_request_body(request_body)
+    assert end_user_id == "123"
+    assert isinstance(end_user_id, str)
@@ -174,3 +174,67 @@ def test_ollama_chat_function_calling():
     print(json.loads(tool_calls[0].function.arguments))

     print(response)
+
+
+def test_ollama_ssl_verify():
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+    import ssl
+    import httpx
+
+    try:
+        response = litellm.completion(
+            model="ollama/llama3.1",
+            messages=[
+                {
+                    "role": "user",
+                    "content": "What's the weather like in San Francisco?",
+                }
+            ],
+            ssl_verify=False,
+        )
+    except Exception as e:
+        print(e)
+
+    client: HTTPHandler = litellm.in_memory_llm_clients_cache.get_cache(
+        "httpx_clientssl_verify_False"
+    )
+
+    test_client = httpx.Client(verify=False)
+    print(client)
+    assert (
+        client.client._transport._pool._ssl_context.verify_mode
+        == test_client._transport._pool._ssl_context.verify_mode
+    )
+
+
+@pytest.mark.parametrize("stream", [True, False])
+@pytest.mark.asyncio
+async def test_async_ollama_ssl_verify(stream):
+    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+    import httpx
+
+    try:
+        response = await litellm.acompletion(
+            model="ollama/llama3.1",
+            messages=[
+                {
+                    "role": "user",
+                    "content": "What's the weather like in San Francisco?",
+                }
+            ],
+            ssl_verify=False,
+            stream=stream,
+        )
+    except Exception as e:
+        print(e)
+
+    client: AsyncHTTPHandler = litellm.in_memory_llm_clients_cache.get_cache(
+        "async_httpx_clientssl_verify_Falseollama"
+    )
+
+    test_client = httpx.AsyncClient(verify=False)
+    print(client)
+    assert (
+        client.client._transport._pool._ssl_context.verify_mode
+        == test_client._transport._pool._ssl_context.verify_mode
+    )
@@ -170,6 +170,12 @@ def test_spend_logs_payload(model_id: Optional[str]):
             "end_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 954146),
             "cache_hit": None,
             "response_cost": 2.4999999999999998e-05,
+            "standard_logging_object": {
+                "request_tags": ["model-anthropic-claude-v2.1", "app-ishaan-prod"],
+                "metadata": {
+                    "user_api_key_end_user_id": "test-user",
+                },
+            },
         },
         "response_obj": litellm.ModelResponse(
             id=model_id,
@@ -192,7 +198,6 @@ def test_spend_logs_payload(model_id: Optional[str]):
         ),
         "start_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 308604),
         "end_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 954146),
-        "end_user_id": None,
     }

     payload: SpendLogsPayload = get_logging_payload(**input_args)
@@ -229,6 +234,7 @@ def test_spend_logs_payload_whisper():
         "metadata": {
             "user_api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
             "user_api_key_alias": None,
+            "user_api_key_end_user_id": "test-user",
             "user_api_end_user_max_budget": None,
             "litellm_api_version": "1.40.19",
             "global_max_parallel_requests": None,
@@ -293,7 +299,6 @@ def test_spend_logs_payload_whisper():
         response_obj=response,
         start_time=datetime.datetime.now(),
         end_time=datetime.datetime.now(),
-        end_user_id="test-user",
     )

     print("payload: ", payload)
@@ -335,13 +340,16 @@ def test_spend_logs_payload_with_prompts_enabled(monkeypatch):
         ),
         "start_time": datetime.datetime.now(),
         "end_time": datetime.datetime.now(),
-        "end_user_id": "user123",
     }

     # Create a standard logging payload
     standard_logging_payload = {
         "messages": [{"role": "user", "content": "Hello!"}],
         "response": {"role": "assistant", "content": "Hi there!"},
+        "metadata": {
+            "user_api_key_end_user_id": "test-user",
+        },
         "request_tags": ["model-anthropic-claude-v2.1", "app-ishaan-prod"],
     }
     input_args["kwargs"]["standard_logging_object"] = standard_logging_payload
@@ -1497,6 +1497,62 @@ def test_custom_openapi(mock_get_openapi_schema):
     assert openapi_schema is not None


+import pytest
+from unittest.mock import MagicMock, AsyncMock
+import asyncio
+from datetime import timedelta
+from litellm.proxy.utils import ProxyUpdateSpend
+
+
+@pytest.mark.asyncio
+async def test_end_user_transactions_reset():
+    # Setup
+    mock_client = MagicMock()
+    mock_client.end_user_list_transactons = {"1": 10.0}  # Bad log
+    mock_client.db.tx = AsyncMock(side_effect=Exception("DB Error"))
+
+    # Call function - should raise error
+    with pytest.raises(Exception):
+        await ProxyUpdateSpend.update_end_user_spend(
+            n_retry_times=0, prisma_client=mock_client, proxy_logging_obj=MagicMock()
+        )
+
+    # Verify cleanup happened
+    assert (
+        mock_client.end_user_list_transactons == {}
+    ), "Transactions list should be empty after error"
+
+
+@pytest.mark.asyncio
+async def test_spend_logs_cleanup_after_error():
+    # Setup test data
+    mock_client = MagicMock()
+    mock_client.spend_log_transactions = [
+        {"id": 1, "amount": 10.0},
+        {"id": 2, "amount": 20.0},
+        {"id": 3, "amount": 30.0},
+    ]
+    # Make the DB operation fail
+    mock_client.db.litellm_spendlogs.create_many = AsyncMock(
+        side_effect=Exception("DB Error")
+    )
+
+    original_logs = mock_client.spend_log_transactions.copy()
+
+    # Call function - should raise error
+    with pytest.raises(Exception):
+        await ProxyUpdateSpend.update_spend_logs(
+            n_retry_times=0,
+            prisma_client=mock_client,
+            db_writer_client=None,  # Test DB write path
+            proxy_logging_obj=MagicMock(),
+        )
+
+    # Verify the first batch was removed from spend_log_transactions
+    assert (
+        mock_client.spend_log_transactions == original_logs[100:]
+    ), "Should remove processed logs even after error"
+
+
 def test_provider_specific_header():
     from litellm.proxy.litellm_pre_call_utils import (
         add_provider_specific_headers_to_request,
@@ -862,3 +862,18 @@ async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypa
         await _jwt_auth_user_api_key_auth_builder(**args)
     else:
         await _jwt_auth_user_api_key_auth_builder(**args)
+
+
+def test_user_api_key_auth_end_user_str():
+    from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth
+
+    user_api_key_args = {
+        "api_key": "sk-1234",
+        "parent_otel_span": None,
+        "user_role": LitellmUserRoles.PROXY_ADMIN,
+        "end_user_id": "1",
+        "user_id": "default_user_id",
+    }
+
+    user_api_key_auth = UserAPIKeyAuth(**user_api_key_args)
+    assert user_api_key_auth.end_user_id == "1"