diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index 469bb693fb..517cad25b0 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -93,11 +93,15 @@ class AsyncHTTPHandler:
         event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]] = None,
         concurrent_limit=1000,
         client_alias: Optional[str] = None,  # name for client in logs
+        ssl_verify: Optional[Union[bool, str]] = None,
     ):
         self.timeout = timeout
         self.event_hooks = event_hooks
         self.client = self.create_client(
-            timeout=timeout, concurrent_limit=concurrent_limit, event_hooks=event_hooks
+            timeout=timeout,
+            concurrent_limit=concurrent_limit,
+            event_hooks=event_hooks,
+            ssl_verify=ssl_verify,
         )
         self.client_alias = client_alias

@@ -106,11 +110,13 @@ class AsyncHTTPHandler:
         timeout: Optional[Union[float, httpx.Timeout]],
         concurrent_limit: int,
         event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]],
+        ssl_verify: Optional[Union[bool, str]] = None,
     ) -> httpx.AsyncClient:
         # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
         # /path/to/certificate.pem
-        ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
+        if ssl_verify is None:
+            ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
         # An SSL certificate used by the requested host to authenticate the client.
         # /path/to/client.pem
         cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)

@@ -440,13 +446,17 @@ class HTTPHandler:
         timeout: Optional[Union[float, httpx.Timeout]] = None,
         concurrent_limit=1000,
         client: Optional[httpx.Client] = None,
+        ssl_verify: Optional[Union[bool, str]] = None,
     ):
         if timeout is None:
             timeout = _DEFAULT_TIMEOUT

         # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
         # /path/to/certificate.pem
-        ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
+
+        if ssl_verify is None:
+            ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
+
         # An SSL certificate used by the requested host to authenticate the client.
         # /path/to/client.pem
         cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)

@@ -506,7 +516,15 @@ class HTTPHandler:
         try:
             if timeout is not None:
                 req = self.client.build_request(
-                    "POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout, files=files, content=content  # type: ignore
+                    "POST",
+                    url,
+                    data=data,  # type: ignore
+                    json=json,
+                    params=params,
+                    headers=headers,
+                    timeout=timeout,
+                    files=files,
+                    content=content,  # type: ignore
                 )
             else:
                 req = self.client.build_request(

@@ -660,6 +678,7 @@ def get_async_httpx_client(
         _new_client = AsyncHTTPHandler(
             timeout=httpx.Timeout(timeout=600.0, connect=5.0)
         )
+
    litellm.in_memory_llm_clients_cache.set_cache(
        key=_cache_key_name,
        value=_new_client,

@@ -684,6 +703,7 @@ def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler:
             pass

     _cache_key_name = "httpx_client" + _params_key_name
+
     _cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key_name)
     if _cached_client:
         return _cached_client

diff --git a/litellm/llms/custom_httpx/httpx_handler.py b/litellm/llms/custom_httpx/httpx_handler.py
index bd5e0d334f..6f684ba01c 100644
--- a/litellm/llms/custom_httpx/httpx_handler.py
+++ b/litellm/llms/custom_httpx/httpx_handler.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union

 import httpx

@@ -36,13 +36,13 @@ class HTTPHandler:
     async def post(
         self,
         url: str,
-        data: Optional[dict] = None,
+        data: Optional[Union[dict, str]] = None,
         params: Optional[dict] = None,
         headers: Optional[dict] = None,
     ):
         try:
             response = await self.client.post(
-                url, data=data, params=params, headers=headers
+                url, data=data, params=params, headers=headers  # type: ignore
             )
             return response
         except Exception as e:
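The changes above thread `ssl_verify` through both handler constructors, falling back to the `SSL_VERIFY` env var / `litellm.ssl_verify` only when the caller passes nothing. A minimal usage sketch (the model name and CA-bundle path are illustrative placeholders, not values from this diff):

```python
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

# Per-handler: False disables verification; a str points at a CA bundle.
handler = AsyncHTTPHandler(ssl_verify="/etc/ssl/certs/internal-ca.pem")

# Per-request: the kwarg flows through get_litellm_params into the
# cached httpx client factories patched above.
response = litellm.completion(
    model="ollama/llama3.1",
    messages=[{"role": "user", "content": "hi"}],
    ssl_verify=False,
)
```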
diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py
index c7ba9cd096..71a8a8168b 100644
--- a/litellm/llms/custom_httpx/llm_http_handler.py
+++ b/litellm/llms/custom_httpx/llm_http_handler.py
@@ -158,7 +158,8 @@ class BaseLLMHTTPHandler:
     ):
         if client is None:
             async_httpx_client = get_async_httpx_client(
-                llm_provider=litellm.LlmProviders(custom_llm_provider)
+                llm_provider=litellm.LlmProviders(custom_llm_provider),
+                params={"ssl_verify": litellm_params.get("ssl_verify", None)},
             )
         else:
             async_httpx_client = client

@@ -318,7 +319,9 @@ class BaseLLMHTTPHandler:
         )

         if client is None or not isinstance(client, HTTPHandler):
-            sync_httpx_client = _get_httpx_client()
+            sync_httpx_client = _get_httpx_client(
+                params={"ssl_verify": litellm_params.get("ssl_verify", None)}
+            )
         else:
             sync_httpx_client = client

@@ -359,7 +362,11 @@ class BaseLLMHTTPHandler:
         client: Optional[HTTPHandler] = None,
     ) -> Tuple[Any, dict]:
         if client is None or not isinstance(client, HTTPHandler):
-            sync_httpx_client = _get_httpx_client()
+            sync_httpx_client = _get_httpx_client(
+                {
+                    "ssl_verify": litellm_params.get("ssl_verify", None),
+                }
+            )
         else:
             sync_httpx_client = client
         stream = True

@@ -411,7 +418,7 @@ class BaseLLMHTTPHandler:
         fake_stream: bool = False,
         client: Optional[AsyncHTTPHandler] = None,
     ):
-        completion_stream, _response_headers = await self.make_async_call(
+        completion_stream, _response_headers = await self.make_async_call_stream_helper(
             custom_llm_provider=custom_llm_provider,
             provider_config=provider_config,
             api_base=api_base,

@@ -432,7 +439,7 @@ class BaseLLMHTTPHandler:
         )
         return streamwrapper

-    async def make_async_call(
+    async def make_async_call_stream_helper(
         self,
         custom_llm_provider: str,
         provider_config: BaseConfig,

@@ -446,9 +453,15 @@ class BaseLLMHTTPHandler:
         fake_stream: bool = False,
         client: Optional[AsyncHTTPHandler] = None,
     ) -> Tuple[Any, httpx.Headers]:
+        """
+        Helper function for making an async call with stream.
+
+        Handles fake stream as well.
+        """
         if client is None:
             async_httpx_client = get_async_httpx_client(
-                llm_provider=litellm.LlmProviders(custom_llm_provider)
+                llm_provider=litellm.LlmProviders(custom_llm_provider),
+                params={"ssl_verify": litellm_params.get("ssl_verify", None)},
             )
         else:
             async_httpx_client = client

diff --git a/litellm/main.py b/litellm/main.py
index 37ef978642..93cf16c601 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -855,6 +855,8 @@ def completion(  # type: ignore # noqa: PLR0915
     cooldown_time = kwargs.get("cooldown_time", None)
     context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
     organization = kwargs.get("organization", None)
+    ### VERIFY SSL ###
+    ssl_verify = kwargs.get("ssl_verify", None)
     ### CUSTOM MODEL COST ###
     input_cost_per_token = kwargs.get("input_cost_per_token", None)
     output_cost_per_token = kwargs.get("output_cost_per_token", None)

@@ -1102,6 +1104,7 @@ def completion(  # type: ignore # noqa: PLR0915
             drop_params=kwargs.get("drop_params"),
             prompt_id=prompt_id,
             prompt_variables=prompt_variables,
+            ssl_verify=ssl_verify,
         )
         logging.update_environment_variables(
             model=model,

diff --git a/litellm/proxy/_new_new_secret_config.yaml b/litellm/proxy/_new_new_secret_config.yaml
new file mode 100644
index 0000000000..7932cc20fe
--- /dev/null
+++ b/litellm/proxy/_new_new_secret_config.yaml
@@ -0,0 +1,14 @@
+model_list:
+  - model_name: bedrock-claude
+    litellm_params:
+      model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
+      aws_region_name: us-east-1
+      aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
+      aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
+
+litellm_settings:
+  callbacks: ["datadog"] # logs llm success + failure logs on datadog
+  service_callback: ["datadog"] # logs redis, postgres failures on datadog
+
+general_settings:
+  store_prompts_in_spend_logs: true

diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 71a3695c58..209e86149d 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -11,7 +11,3 @@ model_list:
     api_base: http://0.0.0.0:8090
     timeout: 2
     num_retries: 0
-
-
-litellm_settings:
-  success_callback: ["langfuse"]
\ No newline at end of file

diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 2b1158213f..e68d92cee6 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1483,7 +1483,8 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
             # Check if the value is None and set the corresponding attribute
             if getattr(self, attr_name, None) is None:
                 kwargs[attr_name] = value
-
+            if key == "end_user_id" and value is not None and isinstance(value, int):
+                kwargs[key] = str(value)
         # Initialize the superclass
         super().__init__(**kwargs)

diff --git a/litellm/proxy/auth/auth_utils.py b/litellm/proxy/auth/auth_utils.py
index 13afe992f0..92f0d49654 100644
--- a/litellm/proxy/auth/auth_utils.py
+++ b/litellm/proxy/auth/auth_utils.py
@@ -498,13 +498,13 @@ def _has_user_setup_sso():

 def get_end_user_id_from_request_body(request_body: dict) -> Optional[str]:
     # openai - check 'user'
-    if "user" in request_body:
-        return request_body["user"]
+    if "user" in request_body and request_body["user"] is not None:
+        return str(request_body["user"])
     # anthropic - check 'litellm_metadata'
     end_user_id = request_body.get("litellm_metadata", {}).get("user", None)
     if end_user_id:
-        return end_user_id
+        return str(end_user_id)
     metadata = request_body.get("metadata")
-    if metadata and "user_id" in metadata:
-        return metadata["user_id"]
+    if metadata and "user_id" in metadata and metadata["user_id"] is not None:
+        return str(metadata["user_id"])
     return None

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 4fb589fe37..3a96205155 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1051,7 +1051,6 @@ async def update_database(  # noqa: PLR0915
                 response_obj=completion_response,
                 start_time=start_time,
                 end_time=end_time,
-                end_user_id=end_user_id,
             )

             payload["spend"] = response_cost
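The `auth_utils.py` change makes `get_end_user_id_from_request_body` always return a `str` (or `None`), so an integer `user` in the request body can no longer leak a non-string ID into spend tracking. A sketch of the resulting contract, consistent with the new test in `tests/local_testing/test_auth_utils.py` further below:

```python
from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body

# Integers are coerced to str; explicit None values fall through.
assert get_end_user_id_from_request_body({"user": 123}) == "123"
assert get_end_user_id_from_request_body({"litellm_metadata": {"user": 42}}) == "42"
assert get_end_user_id_from_request_body({"metadata": {"user_id": None}}) is None
```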
diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py
index c48ff105c0..bf95ceffbd 100644
--- a/litellm/proxy/spend_tracking/spend_tracking_utils.py
+++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py
@@ -10,6 +10,7 @@ from litellm._logging import verbose_proxy_logger
 from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
 from litellm.proxy.utils import PrismaClient, hash_token
 from litellm.types.utils import StandardLoggingPayload
+from litellm.utils import get_end_user_id_for_cost_tracking


 def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:

@@ -29,16 +30,37 @@ def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
     return False


-def get_logging_payload(
-    kwargs, response_obj, start_time, end_time, end_user_id: Optional[str]
-) -> SpendLogsPayload:
-
-    from litellm.proxy.proxy_server import general_settings, master_key
-
+def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
+    if metadata is None:
+        return SpendLogsMetadata(
+            user_api_key=None,
+            user_api_key_alias=None,
+            user_api_key_team_id=None,
+            user_api_key_user_id=None,
+            user_api_key_team_alias=None,
+            spend_logs_metadata=None,
+            requester_ip_address=None,
+            additional_usage_values=None,
+        )
     verbose_proxy_logger.debug(
-        f"SpendTable: get_logging_payload - kwargs: {kwargs}\n\n"
+        "getting payload for SpendLogs, available keys in metadata: "
+        + str(list(metadata.keys()))
     )

+    # Filter the metadata dictionary to include only the specified keys
+    clean_metadata = SpendLogsMetadata(
+        **{  # type: ignore
+            key: metadata[key]
+            for key in SpendLogsMetadata.__annotations__.keys()
+            if key in metadata
+        }
+    )
+    return clean_metadata
+
+
+def get_logging_payload(kwargs, response_obj, start_time, end_time) -> SpendLogsPayload:
+    from litellm.proxy.proxy_server import general_settings, master_key
+
     if kwargs is None:
         kwargs = {}
     if response_obj is None or (

@@ -57,54 +79,34 @@ def get_logging_payload(
     if isinstance(usage, litellm.Usage):
         usage = dict(usage)
     id = cast(dict, response_obj).get("id") or kwargs.get("litellm_call_id")
-    api_key = metadata.get("user_api_key", "")
-    standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
-        "standard_logging_object", None
+    standard_logging_payload = cast(
+        Optional[StandardLoggingPayload], kwargs.get("standard_logging_object", None)
     )
-    if api_key is not None and isinstance(api_key, str):
-        if api_key.startswith("sk-"):
-            # hash the api_key
-            api_key = hash_token(api_key)
-        if (
-            _is_master_key(api_key=api_key, _master_key=master_key)
-            and general_settings.get("disable_adding_master_key_hash_to_db") is True
-        ):
-            api_key = "litellm_proxy_master_key"  # use a known alias, if the user disabled storing master key in db
-
-    _model_id = metadata.get("model_info", {}).get("id", "")
-    _model_group = metadata.get("model_group", "")
+    end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
+    if standard_logging_payload is not None:
+        api_key = standard_logging_payload["metadata"].get("user_api_key_hash") or ""
+        end_user_id = end_user_id or standard_logging_payload["metadata"].get(
+            "user_api_key_end_user_id"
+        )
+    else:
+        api_key = ""

     request_tags = (
         json.dumps(metadata.get("tags", []))
         if isinstance(metadata.get("tags", []), list)
         else "[]"
     )
+    if (
+        _is_master_key(api_key=api_key, _master_key=master_key)
+        and general_settings.get("disable_adding_master_key_hash_to_db") is True
+    ):
+        api_key = "litellm_proxy_master_key"  # use a known alias, if the user disabled storing master key in db
+
+    _model_id = metadata.get("model_info", {}).get("id", "")
+    _model_group = metadata.get("model_group", "")

     # clean up litellm metadata
-    clean_metadata = SpendLogsMetadata(
-        user_api_key=None,
-        user_api_key_alias=None,
-        user_api_key_team_id=None,
-        user_api_key_user_id=None,
-        user_api_key_team_alias=None,
-        spend_logs_metadata=None,
-        requester_ip_address=None,
-        additional_usage_values=None,
-    )
-    if isinstance(metadata, dict):
-        verbose_proxy_logger.debug(
-            "getting payload for SpendLogs, available keys in metadata: "
-            + str(list(metadata.keys()))
-        )
-
-        # Filter the metadata dictionary to include only the specified keys
-        clean_metadata = SpendLogsMetadata(
-            **{  # type: ignore
-                key: metadata[key]
-                for key in SpendLogsMetadata.__annotations__.keys()
-                if key in metadata
-            }
-        )
+    clean_metadata = _get_spend_logs_metadata(metadata)

     special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
     additional_usage_values = {}
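The refactor above pulls metadata cleaning out of `get_logging_payload` into `_get_spend_logs_metadata`, which handles both the `None` case and the key-filtering case. A small sketch of the two paths (`unrelated_key` is an illustrative name, not from this diff):

```python
from litellm.proxy.spend_tracking.spend_tracking_utils import _get_spend_logs_metadata

# None -> a SpendLogsMetadata with every field set to None.
empty = _get_spend_logs_metadata(None)
assert empty["user_api_key"] is None

# dict -> filtered down to the keys declared on SpendLogsMetadata.
filtered = _get_spend_logs_metadata({"user_api_key": "hashed-key", "unrelated_key": 1})
assert filtered["user_api_key"] == "hashed-key"
assert "unrelated_key" not in filtered
```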
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 7843fe916d..40080a107c 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -2432,6 +2432,126 @@ async def reset_budget(prisma_client: PrismaClient):
         )


+class ProxyUpdateSpend:
+    @staticmethod
+    async def update_end_user_spend(
+        n_retry_times: int, prisma_client: PrismaClient, proxy_logging_obj: ProxyLogging
+    ):
+        for i in range(n_retry_times + 1):
+            start_time = time.time()
+            try:
+                async with prisma_client.db.tx(
+                    timeout=timedelta(seconds=60)
+                ) as transaction:
+                    async with transaction.batch_() as batcher:
+                        for (
+                            end_user_id,
+                            response_cost,
+                        ) in prisma_client.end_user_list_transactons.items():
+                            if litellm.max_end_user_budget is not None:
+                                pass
+                            batcher.litellm_endusertable.upsert(
+                                where={"user_id": end_user_id},
+                                data={
+                                    "create": {
+                                        "user_id": end_user_id,
+                                        "spend": response_cost,
+                                        "blocked": False,
+                                    },
+                                    "update": {"spend": {"increment": response_cost}},
+                                },
+                            )
+
+                break
+            except DB_CONNECTION_ERROR_TYPES as e:
+                if i >= n_retry_times:  # If we've reached the maximum number of retries
+                    _raise_failed_update_spend_exception(
+                        e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
+                    )
+                # Optionally, sleep for a bit before retrying
+                await asyncio.sleep(2**i)  # Exponential backoff
+            except Exception as e:
+                _raise_failed_update_spend_exception(
+                    e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
+                )
+            finally:
+                prisma_client.end_user_list_transactons = (
+                    {}
+                )  # reset the end user list transactions - prevent bad data from causing issues
+
+    @staticmethod
+    async def update_spend_logs(
+        n_retry_times: int,
+        prisma_client: PrismaClient,
+        db_writer_client: Optional[HTTPHandler],
+        proxy_logging_obj: ProxyLogging,
+    ):
+        BATCH_SIZE = 100  # Preferred size of each batch to write to the database
+        MAX_LOGS_PER_INTERVAL = (
+            1000  # Maximum number of logs to flush in a single interval
+        )
+        # Get initial logs to process
+        logs_to_process = prisma_client.spend_log_transactions[:MAX_LOGS_PER_INTERVAL]
+        start_time = time.time()
+        try:
+            for i in range(n_retry_times + 1):
+                try:
+                    base_url = os.getenv("SPEND_LOGS_URL", None)
+                    if (
+                        len(logs_to_process) > 0
+                        and base_url is not None
+                        and db_writer_client is not None
+                    ):
+                        if not base_url.endswith("/"):
+                            base_url += "/"
+                        verbose_proxy_logger.debug("base_url: {}".format(base_url))
+                        response = await db_writer_client.post(
+                            url=base_url + "spend/update",
+                            data=json.dumps(logs_to_process),
+                            headers={"Content-Type": "application/json"},
+                        )
+                        if response.status_code == 200:
+                            prisma_client.spend_log_transactions = (
+                                prisma_client.spend_log_transactions[
+                                    len(logs_to_process) :
+                                ]
+                            )
+                    else:
+                        for j in range(0, len(logs_to_process), BATCH_SIZE):
+                            batch = logs_to_process[j : j + BATCH_SIZE]
+                            batch_with_dates = [
+                                prisma_client.jsonify_object({**entry})
+                                for entry in batch
+                            ]
+                            await prisma_client.db.litellm_spendlogs.create_many(
+                                data=batch_with_dates, skip_duplicates=True
+                            )
+                            verbose_proxy_logger.debug(
+                                f"Flushed {len(batch)} logs to the DB."
+                            )
+
+                        prisma_client.spend_log_transactions = (
+                            prisma_client.spend_log_transactions[len(logs_to_process) :]
+                        )
+                        verbose_proxy_logger.debug(
+                            f"{len(logs_to_process)} logs processed. Remaining in queue: {len(prisma_client.spend_log_transactions)}"
+                        )
+                    break
+                except DB_CONNECTION_ERROR_TYPES:
+                    if i is None:
+                        i = 0
+                    if i >= n_retry_times:
+                        raise
+                    await asyncio.sleep(2**i)
+        except Exception as e:
+            prisma_client.spend_log_transactions = prisma_client.spend_log_transactions[
+                len(logs_to_process) :
+            ]
+            _raise_failed_update_spend_exception(
+                e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
+            )
+
+
 async def update_spend(  # noqa: PLR0915
     prisma_client: PrismaClient,
     db_writer_client: Optional[HTTPHandler],

@@ -2490,47 +2610,11 @@ async def update_spend(  # noqa: PLR0915
             )
         )
     if len(prisma_client.end_user_list_transactons.keys()) > 0:
-        for i in range(n_retry_times + 1):
-            start_time = time.time()
-            try:
-                async with prisma_client.db.tx(
-                    timeout=timedelta(seconds=60)
-                ) as transaction:
-                    async with transaction.batch_() as batcher:
-                        for (
-                            end_user_id,
-                            response_cost,
-                        ) in prisma_client.end_user_list_transactons.items():
-                            if litellm.max_end_user_budget is not None:
-                                pass
-                            batcher.litellm_endusertable.upsert(
-                                where={"user_id": end_user_id},
-                                data={
-                                    "create": {
-                                        "user_id": end_user_id,
-                                        "spend": response_cost,
-                                        "blocked": False,
-                                    },
-                                    "update": {"spend": {"increment": response_cost}},
-                                },
-                            )
-
-                prisma_client.end_user_list_transactons = (
-                    {}
-                )  # Clear the remaining transactions after processing all batches in the loop.
-                break
-            except DB_CONNECTION_ERROR_TYPES as e:
-                if i >= n_retry_times:  # If we've reached the maximum number of retries
-                    _raise_failed_update_spend_exception(
-                        e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
-                    )
-                # Optionally, sleep for a bit before retrying
-                await asyncio.sleep(2**i)  # Exponential backoff
-            except Exception as e:
-                _raise_failed_update_spend_exception(
-                    e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
-                )
-
+        await ProxyUpdateSpend.update_end_user_spend(
+            n_retry_times=n_retry_times,
+            prisma_client=prisma_client,
+            proxy_logging_obj=proxy_logging_obj,
+        )
     ### UPDATE KEY TABLE ###
     verbose_proxy_logger.debug(
         "KEY Spend transactions: {}".format(

@@ -2687,80 +2771,13 @@ async def update_spend(  # noqa: PLR0915
         "Spend Logs transactions: {}".format(len(prisma_client.spend_log_transactions))
     )

-    BATCH_SIZE = 100  # Preferred size of each batch to write to the database
-    MAX_LOGS_PER_INTERVAL = 1000  # Maximum number of logs to flush in a single interval
-
     if len(prisma_client.spend_log_transactions) > 0:
-        for i in range(n_retry_times + 1):
-            start_time = time.time()
-            try:
-                base_url = os.getenv("SPEND_LOGS_URL", None)
-                ## WRITE TO SEPARATE SERVER ##
-                if (
-                    len(prisma_client.spend_log_transactions) > 0
-                    and base_url is not None
-                    and db_writer_client is not None
-                ):
-                    if not base_url.endswith("/"):
-                        base_url += "/"
-                    verbose_proxy_logger.debug("base_url: {}".format(base_url))
-                    response = await db_writer_client.post(
-                        url=base_url + "spend/update",
-                        data=json.dumps(prisma_client.spend_log_transactions),  # type: ignore
-                        headers={"Content-Type": "application/json"},
-                    )
-                    if response.status_code == 200:
-                        prisma_client.spend_log_transactions = []
-                else:  ## (default) WRITE TO DB ##
-                    logs_to_process = prisma_client.spend_log_transactions[
-                        :MAX_LOGS_PER_INTERVAL
-                    ]
-                    for j in range(0, len(logs_to_process), BATCH_SIZE):
-                        # Create sublist for current batch, ensuring it doesn't exceed the BATCH_SIZE
-                        batch = logs_to_process[j : j + BATCH_SIZE]
-
-                        # Convert datetime strings to Date objects
-                        batch_with_dates = [
-                            prisma_client.jsonify_object(
-                                {
-                                    **entry,
-                                }
-                            )
-                            for entry in batch
-                        ]
-
-                        await prisma_client.db.litellm_spendlogs.create_many(
-                            data=batch_with_dates, skip_duplicates=True  # type: ignore
-                        )
-
-                        verbose_proxy_logger.debug(
-                            f"Flushed {len(batch)} logs to the DB."
-                        )
-                    # Remove the processed logs from spend_logs
-                    prisma_client.spend_log_transactions = (
-                        prisma_client.spend_log_transactions[len(logs_to_process) :]
-                    )
-
-                    verbose_proxy_logger.debug(
-                        f"{len(logs_to_process)} logs processed. Remaining in queue: {len(prisma_client.spend_log_transactions)}"
-                    )
-                break
-            except DB_CONNECTION_ERROR_TYPES as e:
-                if i is None:
-                    i = 0
-                if (
-                    i >= n_retry_times
-                ):  # If we've reached the maximum number of retries raise the exception
-                    _raise_failed_update_spend_exception(
-                        e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
-                    )
-
-                # Optionally, sleep for a bit before retrying
-                await asyncio.sleep(2**i)  # type: ignore
-            except Exception as e:
-                _raise_failed_update_spend_exception(
-                    e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
-                )
+        await ProxyUpdateSpend.update_spend_logs(
+            n_retry_times=n_retry_times,
+            prisma_client=prisma_client,
+            proxy_logging_obj=proxy_logging_obj,
+            db_writer_client=db_writer_client,
+        )


 def _raise_failed_update_spend_exception(

diff --git a/litellm/utils.py b/litellm/utils.py
index a9cf50c5c6..dd43355f01 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2132,6 +2132,7 @@ def get_litellm_params(
     prompt_id: Optional[str] = None,
     prompt_variables: Optional[dict] = None,
     async_call: Optional[bool] = None,
+    ssl_verify: Optional[bool] = None,
     **kwargs,
 ) -> dict:
     litellm_params = {

@@ -2170,6 +2171,7 @@ def get_litellm_params(
         "prompt_id": prompt_id,
         "prompt_variables": prompt_variables,
         "async_call": async_call,
+        "ssl_verify": ssl_verify,
     }
     return litellm_params

diff --git a/tests/local_testing/test_auth_utils.py b/tests/local_testing/test_auth_utils.py
index 1118b8a63b..73abedb3f0 100644
--- a/tests/local_testing/test_auth_utils.py
+++ b/tests/local_testing/test_auth_utils.py
@@ -68,3 +68,12 @@ def test_configurable_clientside_parameters(
     )
     print(resp)
     assert resp == should_return_true
+
+
+def test_get_end_user_id_from_request_body_always_returns_str():
+    from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body
+
+    request_body = {"user": 123}
+    end_user_id = get_end_user_id_from_request_body(request_body)
+    assert end_user_id == "123"
+    assert isinstance(end_user_id, str)
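Both `ProxyUpdateSpend` methods extracted above keep the same retry shape: attempt the write, back off exponentially on connection errors, and surface the failure once the retry budget is spent. A standalone sketch of that pattern, with `ConnectionError` standing in for `DB_CONNECTION_ERROR_TYPES`:

```python
import asyncio
from typing import Awaitable, Callable


async def write_with_retries(
    write_fn: Callable[[], Awaitable[None]], n_retry_times: int
) -> None:
    for i in range(n_retry_times + 1):
        try:
            await write_fn()
            break
        except ConnectionError:  # stand-in for DB_CONNECTION_ERROR_TYPES
            if i >= n_retry_times:
                raise  # retry budget exhausted
            await asyncio.sleep(2**i)  # exponential backoff: 1s, 2s, 4s, ...
```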
"async_httpx_clientssl_verify_Falseollama" + ) + + test_client = httpx.AsyncClient(verify=False) + print(client) + assert ( + client.client._transport._pool._ssl_context.verify_mode + == test_client._transport._pool._ssl_context.verify_mode + ) diff --git a/tests/logging_callback_tests/test_spend_logs.py b/tests/logging_callback_tests/test_spend_logs.py index 9ecfe47046..faad534cec 100644 --- a/tests/logging_callback_tests/test_spend_logs.py +++ b/tests/logging_callback_tests/test_spend_logs.py @@ -170,6 +170,12 @@ def test_spend_logs_payload(model_id: Optional[str]): "end_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 954146), "cache_hit": None, "response_cost": 2.4999999999999998e-05, + "standard_logging_object": { + "request_tags": ["model-anthropic-claude-v2.1", "app-ishaan-prod"], + "metadata": { + "user_api_key_end_user_id": "test-user", + }, + }, }, "response_obj": litellm.ModelResponse( id=model_id, @@ -192,7 +198,6 @@ def test_spend_logs_payload(model_id: Optional[str]): ), "start_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 308604), "end_time": datetime.datetime(2024, 6, 7, 12, 43, 30, 954146), - "end_user_id": None, } payload: SpendLogsPayload = get_logging_payload(**input_args) @@ -229,6 +234,7 @@ def test_spend_logs_payload_whisper(): "metadata": { "user_api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b", "user_api_key_alias": None, + "user_api_key_end_user_id": "test-user", "user_api_end_user_max_budget": None, "litellm_api_version": "1.40.19", "global_max_parallel_requests": None, @@ -293,7 +299,6 @@ def test_spend_logs_payload_whisper(): response_obj=response, start_time=datetime.datetime.now(), end_time=datetime.datetime.now(), - end_user_id="test-user", ) print("payload: ", payload) @@ -335,13 +340,16 @@ def test_spend_logs_payload_with_prompts_enabled(monkeypatch): ), "start_time": datetime.datetime.now(), "end_time": datetime.datetime.now(), - "end_user_id": "user123", } # Create a standard logging payload standard_logging_payload = { "messages": [{"role": "user", "content": "Hello!"}], "response": {"role": "assistant", "content": "Hi there!"}, + "metadata": { + "user_api_key_end_user_id": "test-user", + }, + "request_tags": ["model-anthropic-claude-v2.1", "app-ishaan-prod"], } input_args["kwargs"]["standard_logging_object"] = standard_logging_payload diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index d736067612..3f0b127af4 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -1497,6 +1497,62 @@ def test_custom_openapi(mock_get_openapi_schema): assert openapi_schema is not None +import pytest +from unittest.mock import MagicMock, AsyncMock +import asyncio +from datetime import timedelta +from litellm.proxy.utils import ProxyUpdateSpend + + +@pytest.mark.asyncio +async def test_end_user_transactions_reset(): + # Setup + mock_client = MagicMock() + mock_client.end_user_list_transactons = {"1": 10.0} # Bad log + mock_client.db.tx = AsyncMock(side_effect=Exception("DB Error")) + + # Call function - should raise error + with pytest.raises(Exception): + await ProxyUpdateSpend.update_end_user_spend( + n_retry_times=0, prisma_client=mock_client, proxy_logging_obj=MagicMock() + ) + + # Verify cleanup happened + assert ( + mock_client.end_user_list_transactons == {} + ), "Transactions list should be empty after error" + + +@pytest.mark.asyncio +async def test_spend_logs_cleanup_after_error(): + # Setup test data + mock_client = MagicMock() 
diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py
index d736067612..3f0b127af4 100644
--- a/tests/proxy_unit_tests/test_proxy_utils.py
+++ b/tests/proxy_unit_tests/test_proxy_utils.py
@@ -1497,6 +1497,62 @@ def test_custom_openapi(mock_get_openapi_schema):
     assert openapi_schema is not None


+import pytest
+from unittest.mock import MagicMock, AsyncMock
+import asyncio
+from datetime import timedelta
+from litellm.proxy.utils import ProxyUpdateSpend
+
+
+@pytest.mark.asyncio
+async def test_end_user_transactions_reset():
+    # Setup
+    mock_client = MagicMock()
+    mock_client.end_user_list_transactons = {"1": 10.0}  # Bad log
+    mock_client.db.tx = AsyncMock(side_effect=Exception("DB Error"))
+
+    # Call function - should raise error
+    with pytest.raises(Exception):
+        await ProxyUpdateSpend.update_end_user_spend(
+            n_retry_times=0, prisma_client=mock_client, proxy_logging_obj=MagicMock()
+        )
+
+    # Verify cleanup happened
+    assert (
+        mock_client.end_user_list_transactons == {}
+    ), "Transactions list should be empty after error"
+
+
+@pytest.mark.asyncio
+async def test_spend_logs_cleanup_after_error():
+    # Setup test data
+    mock_client = MagicMock()
+    mock_client.spend_log_transactions = [
+        {"id": 1, "amount": 10.0},
+        {"id": 2, "amount": 20.0},
+        {"id": 3, "amount": 30.0},
+    ]
+    # Make the DB operation fail
+    mock_client.db.litellm_spendlogs.create_many = AsyncMock(
+        side_effect=Exception("DB Error")
+    )
+
+    original_logs = mock_client.spend_log_transactions.copy()
+
+    # Call function - should raise error
+    with pytest.raises(Exception):
+        await ProxyUpdateSpend.update_spend_logs(
+            n_retry_times=0,
+            prisma_client=mock_client,
+            db_writer_client=None,  # Test DB write path
+            proxy_logging_obj=MagicMock(),
+        )
+
+    # Verify the first batch was removed from spend_log_transactions
+    assert (
+        mock_client.spend_log_transactions == original_logs[100:]
+    ), "Should remove processed logs even after error"
+
+
 def test_provider_specific_header():
     from litellm.proxy.litellm_pre_call_utils import (
         add_provider_specific_headers_to_request,

diff --git a/tests/proxy_unit_tests/test_user_api_key_auth.py b/tests/proxy_unit_tests/test_user_api_key_auth.py
index 9940299622..a428a29c63 100644
--- a/tests/proxy_unit_tests/test_user_api_key_auth.py
+++ b/tests/proxy_unit_tests/test_user_api_key_auth.py
@@ -862,3 +862,18 @@ async def test_jwt_user_api_key_auth_builder_enforce_rbac(enforce_rbac, monkeypa
         await _jwt_auth_user_api_key_auth_builder(**args)
     else:
         await _jwt_auth_user_api_key_auth_builder(**args)
+
+
+def test_user_api_key_auth_end_user_str():
+    from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth
+
+    user_api_key_args = {
+        "api_key": "sk-1234",
+        "parent_otel_span": None,
+        "user_role": LitellmUserRoles.PROXY_ADMIN,
+        "end_user_id": "1",
+        "user_id": "default_user_id",
+    }
+
+    user_api_key_auth = UserAPIKeyAuth(**user_api_key_args)
+    assert user_api_key_auth.end_user_id == "1"