diff --git a/docs/my-website/docs/proxy/custom_pricing.md b/docs/my-website/docs/proxy/custom_pricing.md
new file mode 100644
index 000000000..10ae06667
--- /dev/null
+++ b/docs/my-website/docs/proxy/custom_pricing.md
@@ -0,0 +1,34 @@
+import Image from '@theme/IdealImage';
+
+# Custom Pricing - Sagemaker, etc.
+
+Use this to register custom pricing (cost per token or cost per second) for models.
+
+## Quick Start
+
+Register custom pricing for Sagemaker completion + embedding models.
+
+For cost-per-second pricing, you **only** need to register `input_cost_per_second`.
+
+**Step 1: Add pricing to config.yaml**
+```yaml
+model_list:
+  - model_name: sagemaker-completion-model
+    litellm_params:
+      model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
+      input_cost_per_second: 0.000420
+  - model_name: sagemaker-embedding-model
+    litellm_params:
+      model: sagemaker/berri-benchmarking-gpt-j-6b-fp16
+      input_cost_per_second: 0.000420
+```
+
+**Step 2: Start proxy**
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+**Step 3: View Spend Logs**
+
+<Image img={require('../../img/spend_logs_table.png')} />
\ No newline at end of file
diff --git a/docs/my-website/docs/proxy/virtual_keys.md b/docs/my-website/docs/proxy/virtual_keys.md
index 1c7e0631a..e1c89bbc2 100644
--- a/docs/my-website/docs/proxy/virtual_keys.md
+++ b/docs/my-website/docs/proxy/virtual_keys.md
@@ -440,6 +440,97 @@ general_settings:
 $ litellm --config /path/to/config.yaml
 ```
 
+## Custom /key/generate
+
+If you need to add custom logic before generating a Proxy API Key (for example, validating the `team_id`), you can register a custom key-generation function.
+
+### 1. Write a custom `custom_generate_key_fn`
+
+
+The input to the `custom_generate_key_fn` function is a single parameter: `data` [(Type: GenerateKeyRequest)](https://github.com/BerriAI/litellm/blob/main/litellm/proxy/_types.py#L125)
+
+The output of your `custom_generate_key_fn` should be a dictionary with the following structure:
+```python
+{
+    "decision": False,
+    "message": "This violates LiteLLM Proxy Rules. No team id provided.",
+}
+
+```
+
+- `decision` (Type: bool): A boolean value indicating whether the key generation is allowed (True) or not (False).
+
+- `message` (Type: str, Optional): An optional message providing additional information about the decision. This field is included when the decision is False.
+
+
+```python
+async def custom_generate_key_fn(data: GenerateKeyRequest)-> dict:
+        """
+        Asynchronous function for generating a key based on the input data.
+
+        Args:
+            data (GenerateKeyRequest): The input data for key generation.
+
+        Returns:
+            dict: A dictionary containing the decision and an optional message.
+            {
+                "decision": False,
+                "message": "This violates LiteLLM Proxy Rules. 
No team id provided.", + } + """ + + # decide if a key should be generated or not + print("using custom auth function!") + data_json = data.json() # type: ignore + + # Unpacking variables + team_id = data_json.get("team_id") + duration = data_json.get("duration") + models = data_json.get("models") + aliases = data_json.get("aliases") + config = data_json.get("config") + spend = data_json.get("spend") + user_id = data_json.get("user_id") + max_parallel_requests = data_json.get("max_parallel_requests") + metadata = data_json.get("metadata") + tpm_limit = data_json.get("tpm_limit") + rpm_limit = data_json.get("rpm_limit") + + if team_id is not None and team_id == "litellm-core-infra@gmail.com": + # only team_id="litellm-core-infra@gmail.com" can make keys + return { + "decision": True, + } + else: + print("Failed custom auth") + return { + "decision": False, + "message": "This violates LiteLLM Proxy Rules. No team id provided.", + } +``` + + +### 2. Pass the filepath (relative to the config.yaml) + +Pass the filepath to the config.yaml + +e.g. if they're both in the same dir - `./config.yaml` and `./custom_auth.py`, this is what it looks like: +```yaml +model_list: + - model_name: "openai-model" + litellm_params: + model: "gpt-3.5-turbo" + +litellm_settings: + drop_params: True + set_verbose: True + +general_settings: + custom_key_generate: custom_auth.custom_generate_key_fn +``` + + + ## [BETA] Dynamo DB diff --git a/docs/my-website/img/spend_logs_table.png b/docs/my-website/img/spend_logs_table.png new file mode 100644 index 000000000..a0f259244 Binary files /dev/null and b/docs/my-website/img/spend_logs_table.png differ diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 900e7bc5f..8e20426fd 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -139,6 +139,7 @@ const sidebars = { "items": [ "proxy/call_hooks", "proxy/rules", + "proxy/custom_pricing" ] }, "proxy/deploy", diff --git a/litellm/_logging.py b/litellm/_logging.py index 0bd82a6bd..b1276c045 100644 --- a/litellm/_logging.py +++ b/litellm/_logging.py @@ -12,15 +12,6 @@ formatter = logging.Formatter("\033[92m%(name)s - %(levelname)s\033[0m: %(messag handler.setFormatter(formatter) - -def print_verbose(print_statement): - try: - if set_verbose: - print(print_statement) # noqa - except: - pass - - verbose_proxy_logger = logging.getLogger("LiteLLM Proxy") verbose_router_logger = logging.getLogger("LiteLLM Router") verbose_logger = logging.getLogger("LiteLLM") @@ -29,3 +20,18 @@ verbose_logger = logging.getLogger("LiteLLM") verbose_router_logger.addHandler(handler) verbose_proxy_logger.addHandler(handler) verbose_logger.addHandler(handler) + + +def print_verbose(print_statement): + try: + if set_verbose: + print(print_statement) # noqa + verbose_logger.setLevel(level=logging.DEBUG) # set package log to debug + verbose_router_logger.setLevel( + level=logging.DEBUG + ) # set router logs to debug + verbose_proxy_logger.setLevel( + level=logging.DEBUG + ) # set proxy logs to debug + except: + pass diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index 0eb70c86f..f20a2e939 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -629,12 +629,23 @@ class AzureChatCompletion(BaseLLM): client_session = litellm.aclient_session or httpx.AsyncClient( transport=AsyncCustomHTTPTransport(), ) - openai_aclient = AsyncAzureOpenAI( + azure_client = AsyncAzureOpenAI( http_client=client_session, **azure_client_params ) else: - openai_aclient = client - response = await 
openai_aclient.images.generate(**data, timeout=timeout) + azure_client = client + ## LOGGING + logging_obj.pre_call( + input=data["prompt"], + api_key=azure_client.api_key, + additional_args={ + "headers": {"api_key": azure_client.api_key}, + "api_base": azure_client._base_url._uri_reference, + "acompletion": True, + "complete_input_dict": data, + }, + ) + response = await azure_client.images.generate(**data, timeout=timeout) stringified_response = response.model_dump() ## LOGGING logging_obj.post_call( @@ -719,7 +730,7 @@ class AzureChatCompletion(BaseLLM): input=prompt, api_key=azure_client.api_key, additional_args={ - "headers": {"Authorization": f"Bearer {azure_client.api_key}"}, + "headers": {"api_key": azure_client.api_key}, "api_base": azure_client._base_url._uri_reference, "acompletion": False, "complete_input_dict": data, diff --git a/litellm/llms/custom_httpx/azure_dall_e_2.py b/litellm/llms/custom_httpx/azure_dall_e_2.py index a62e1d666..f361ede5b 100644 --- a/litellm/llms/custom_httpx/azure_dall_e_2.py +++ b/litellm/llms/custom_httpx/azure_dall_e_2.py @@ -43,7 +43,7 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport): request=request, ) - time.sleep(int(response.headers.get("retry-after")) or 10) + await asyncio.sleep(int(response.headers.get("retry-after") or 10)) response = await super().handle_async_request(request) await response.aread() @@ -95,7 +95,6 @@ class CustomHTTPTransport(httpx.HTTPTransport): request.method = "GET" response = super().handle_request(request) response.read() - timeout_secs: int = 120 start_time = time.time() while response.json()["status"] not in ["succeeded", "failed"]: @@ -112,11 +111,9 @@ class CustomHTTPTransport(httpx.HTTPTransport): content=json.dumps(timeout).encode("utf-8"), request=request, ) - - time.sleep(int(response.headers.get("retry-after")) or 10) + time.sleep(int(response.headers.get("retry-after", None) or 10)) response = super().handle_request(request) response.read() - if response.json()["status"] == "failed": error_data = response.json() return httpx.Response( diff --git a/litellm/main.py b/litellm/main.py index 271c54e51..2d8f2c0c9 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -348,6 +348,13 @@ def mock_completion( prompt_tokens=10, completion_tokens=20, total_tokens=30 ) + try: + _, custom_llm_provider, _, _ = litellm.utils.get_llm_provider(model=model) + model_response._hidden_params["custom_llm_provider"] = custom_llm_provider + except: + # dont let setting a hidden param block a mock_respose + pass + return model_response except: @@ -450,6 +457,8 @@ def completion( ### CUSTOM MODEL COST ### input_cost_per_token = kwargs.get("input_cost_per_token", None) output_cost_per_token = kwargs.get("output_cost_per_token", None) + input_cost_per_second = kwargs.get("input_cost_per_second", None) + output_cost_per_second = kwargs.get("output_cost_per_second", None) ### CUSTOM PROMPT TEMPLATE ### initial_prompt_value = kwargs.get("initial_prompt_value", None) roles = kwargs.get("roles", None) @@ -527,6 +536,8 @@ def completion( "tpm", "input_cost_per_token", "output_cost_per_token", + "input_cost_per_second", + "output_cost_per_second", "hf_model_name", "model_info", "proxy_server_request", @@ -589,6 +600,19 @@ def completion( } } ) + if ( + input_cost_per_second is not None + ): # time based pricing just needs cost in place + output_cost_per_second = output_cost_per_second or 0.0 + litellm.register_model( + { + model: { + "input_cost_per_second": input_cost_per_second, + "output_cost_per_second": output_cost_per_second, 
+ "litellm_provider": custom_llm_provider, + } + } + ) ### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ### custom_prompt_dict = {} # type: ignore if ( @@ -2240,6 +2264,11 @@ def embedding( encoding_format = kwargs.get("encoding_format", None) proxy_server_request = kwargs.get("proxy_server_request", None) aembedding = kwargs.get("aembedding", None) + ### CUSTOM MODEL COST ### + input_cost_per_token = kwargs.get("input_cost_per_token", None) + output_cost_per_token = kwargs.get("output_cost_per_token", None) + input_cost_per_second = kwargs.get("input_cost_per_second", None) + output_cost_per_second = kwargs.get("output_cost_per_second", None) openai_params = [ "user", "request_timeout", @@ -2288,6 +2317,8 @@ def embedding( "tpm", "input_cost_per_token", "output_cost_per_token", + "input_cost_per_second", + "output_cost_per_second", "hf_model_name", "proxy_server_request", "model_info", @@ -2313,6 +2344,28 @@ def embedding( custom_llm_provider=custom_llm_provider, **non_default_params, ) + ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ### + if input_cost_per_token is not None and output_cost_per_token is not None: + litellm.register_model( + { + model: { + "input_cost_per_token": input_cost_per_token, + "output_cost_per_token": output_cost_per_token, + "litellm_provider": custom_llm_provider, + } + } + ) + if input_cost_per_second is not None: # time based pricing just needs cost in place + output_cost_per_second = output_cost_per_second or 0.0 + litellm.register_model( + { + model: { + "input_cost_per_second": input_cost_per_second, + "output_cost_per_second": output_cost_per_second, + "litellm_provider": custom_llm_provider, + } + } + ) try: response = None logging = litellm_logging_obj @@ -3281,7 +3334,9 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List] return response -def stream_chunk_builder(chunks: list, messages: Optional[list] = None): +def stream_chunk_builder( + chunks: list, messages: Optional[list] = None, start_time=None, end_time=None +): model_response = litellm.ModelResponse() # set hidden params from chunk to model_response if model_response is not None and hasattr(model_response, "_hidden_params"): @@ -3456,5 +3511,8 @@ def stream_chunk_builder(chunks: list, messages: Optional[list] = None): response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"] ) return convert_to_model_response_object( - response_object=response, model_response_object=model_response + response_object=response, + model_response_object=model_response, + start_time=start_time, + end_time=end_time, ) diff --git a/litellm/proxy/example_config_yaml/custom_auth.py b/litellm/proxy/example_config_yaml/custom_auth.py index 416b66682..a764a647a 100644 --- a/litellm/proxy/example_config_yaml/custom_auth.py +++ b/litellm/proxy/example_config_yaml/custom_auth.py @@ -1,4 +1,4 @@ -from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy._types import UserAPIKeyAuth, GenerateKeyRequest from fastapi import Request from dotenv import load_dotenv import os @@ -14,3 +14,40 @@ async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth: raise Exception except: raise Exception + + +async def generate_key_fn(data: GenerateKeyRequest): + """ + Asynchronously decides if a key should be generated or not based on the provided data. + + Args: + data (GenerateKeyRequest): The data to be used for decision making. + + Returns: + bool: True if a key should be generated, False otherwise. 
+ """ + # decide if a key should be generated or not + data_json = data.json() # type: ignore + + # Unpacking variables + team_id = data_json.get("team_id") + duration = data_json.get("duration") + models = data_json.get("models") + aliases = data_json.get("aliases") + config = data_json.get("config") + spend = data_json.get("spend") + user_id = data_json.get("user_id") + max_parallel_requests = data_json.get("max_parallel_requests") + metadata = data_json.get("metadata") + tpm_limit = data_json.get("tpm_limit") + rpm_limit = data_json.get("rpm_limit") + + if team_id is not None and len(team_id) > 0: + return { + "decision": True, + } + else: + return { + "decision": True, + "message": "This violates LiteLLM Proxy Rules. No team id provided.", + } diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 417b4c6f1..29aa3cf4f 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -62,8 +62,9 @@ litellm_settings: # setting callback class # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance] -# general_settings: -# master_key: sk-1234 +general_settings: + master_key: sk-1234 + custom_key_generate: custom_auth.generate_key_fn # database_type: "dynamo_db" # database_args: { # 👈 all args - https://github.com/BerriAI/litellm/blob/befbcbb7ac8f59835ce47415c128decf37aac328/litellm/proxy/_types.py#L190 # "billing_mode": "PAY_PER_REQUEST", diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index c20353610..24caf5b94 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -187,6 +187,7 @@ prisma_client: Optional[PrismaClient] = None custom_db_client: Optional[DBClient] = None user_api_key_cache = DualCache() user_custom_auth = None +user_custom_key_generate = None use_background_health_checks = None use_queue = False health_check_interval = None @@ -584,7 +585,7 @@ async def track_cost_callback( "user_api_key_user_id", None ) - verbose_proxy_logger.debug( + verbose_proxy_logger.info( f"streaming response_cost {response_cost}, for user_id {user_id}" ) if user_api_key and ( @@ -609,7 +610,7 @@ async def track_cost_callback( user_id = user_id or kwargs["litellm_params"]["metadata"].get( "user_api_key_user_id", None ) - verbose_proxy_logger.debug( + verbose_proxy_logger.info( f"response_cost {response_cost}, for user_id {user_id}" ) if user_api_key and ( @@ -896,7 +897,7 @@ class ProxyConfig: """ Load config values into proxy global state """ - global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue, custom_db_client + global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client # Load existing config config = await self.get_config(config_file_path=config_file_path) @@ -1074,6 +1075,12 @@ class ProxyConfig: user_custom_auth = get_instance_fn( value=custom_auth, config_file_path=config_file_path ) + + custom_key_generate = general_settings.get("custom_key_generate", None) + if custom_key_generate is not None: + user_custom_key_generate = get_instance_fn( + value=custom_key_generate, config_file_path=config_file_path + ) ## dynamodb database_type = general_settings.get("database_type", None) if database_type is not None and ( @@ -2189,7 +2196,16 @@ async def generate_key_fn( - expires: (datetime) Datetime 
object for when key expires. - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id. """ + global user_custom_key_generate verbose_proxy_logger.debug("entered /key/generate") + + if user_custom_key_generate is not None: + result = await user_custom_key_generate(data) + decision = result.get("decision", True) + message = result.get("message", "Authentication Failed - Custom Auth Rule") + if not decision: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=message) + data_json = data.json() # type: ignore # if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users @@ -2978,7 +2994,7 @@ async def get_routes(): @router.on_event("shutdown") async def shutdown_event(): - global prisma_client, master_key, user_custom_auth + global prisma_client, master_key, user_custom_auth, user_custom_key_generate if prisma_client: verbose_proxy_logger.debug("Disconnecting from Prisma") await prisma_client.disconnect() @@ -2988,7 +3004,7 @@ async def shutdown_event(): def cleanup_router_config_variables(): - global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, prisma_client, custom_db_client + global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, prisma_client, custom_db_client # Set all variables to None master_key = None @@ -2996,6 +3012,7 @@ def cleanup_router_config_variables(): otel_logging = None user_custom_auth = None user_custom_auth_path = None + user_custom_key_generate = None use_background_health_checks = None health_check_interval = None prisma_client = None diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 9f183644d..c19137d57 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -449,6 +449,7 @@ class PrismaClient: "update": {}, # don't do anything if it already exists }, ) + verbose_proxy_logger.info(f"Data Inserted into Keys Table") return new_verification_token elif table_name == "user": db_data = self.jsonify_object(data=data) @@ -459,6 +460,7 @@ class PrismaClient: "update": {}, # don't do anything if it already exists }, ) + verbose_proxy_logger.info(f"Data Inserted into User Table") return new_user_row elif table_name == "config": """ @@ -483,6 +485,7 @@ class PrismaClient: tasks.append(updated_table_row) await asyncio.gather(*tasks) + verbose_proxy_logger.info(f"Data Inserted into Config Table") elif table_name == "spend": db_data = self.jsonify_object(data=data) new_spend_row = await self.db.litellm_spendlogs.upsert( @@ -492,6 +495,7 @@ class PrismaClient: "update": {}, # don't do anything if it already exists }, ) + verbose_proxy_logger.info(f"Data Inserted into Spend Table") return new_spend_row except Exception as e: diff --git a/litellm/router.py b/litellm/router.py index b15687f67..38ebcc1c9 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -997,6 +997,9 @@ class Router: """ try: kwargs["model"] = mg + kwargs.setdefault("metadata", {}).update( + {"model_group": mg} + ) # update model_group used, if fallbacks are done response = await self.async_function_with_retries( *args, **kwargs ) @@ -1025,8 +1028,10 @@ class Router: f"Falling back to model_group = {mg}" ) kwargs["model"] = mg - kwargs["metadata"]["model_group"] = mg - response = await self.async_function_with_retries( + 
kwargs.setdefault("metadata", {}).update( + {"model_group": mg} + ) # update model_group used, if fallbacks are done + response = await self.async_function_with_fallbacks( *args, **kwargs ) return response @@ -1191,6 +1196,9 @@ class Router: ## LOGGING kwargs = self.log_retry(kwargs=kwargs, e=original_exception) kwargs["model"] = mg + kwargs.setdefault("metadata", {}).update( + {"model_group": mg} + ) # update model_group used, if fallbacks are done response = self.function_with_fallbacks(*args, **kwargs) return response except Exception as e: @@ -1214,6 +1222,9 @@ class Router: ## LOGGING kwargs = self.log_retry(kwargs=kwargs, e=original_exception) kwargs["model"] = mg + kwargs.setdefault("metadata", {}).update( + {"model_group": mg} + ) # update model_group used, if fallbacks are done response = self.function_with_fallbacks(*args, **kwargs) return response except Exception as e: diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 247ae4676..644b348ec 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -1372,16 +1372,21 @@ def test_customprompt_together_ai(): def test_completion_sagemaker(): try: - print("testing sagemaker") litellm.set_verbose = True + print("testing sagemaker") response = completion( model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", messages=messages, temperature=0.2, max_tokens=80, + input_cost_per_second=0.000420, ) # Add any assertions here to check the response print(response) + cost = completion_cost(completion_response=response) + assert ( + cost > 0.0 and cost < 1.0 + ) # should never be > $1 for a single completion call except Exception as e: pytest.fail(f"Error occurred: {e}") diff --git a/litellm/tests/test_custom_logger.py b/litellm/tests/test_custom_logger.py index de7dd67b4..565df5b25 100644 --- a/litellm/tests/test_custom_logger.py +++ b/litellm/tests/test_custom_logger.py @@ -1,56 +1,58 @@ ### What this tests #### import sys, os, time, inspect, asyncio, traceback import pytest -sys.path.insert(0, os.path.abspath('../..')) + +sys.path.insert(0, os.path.abspath("../..")) from litellm import completion, embedding import litellm from litellm.integrations.custom_logger import CustomLogger + class MyCustomHandler(CustomLogger): complete_streaming_response_in_callback = "" + def __init__(self): - self.success: bool = False # type: ignore - self.failure: bool = False # type: ignore - self.async_success: bool = False # type: ignore + self.success: bool = False # type: ignore + self.failure: bool = False # type: ignore + self.async_success: bool = False # type: ignore self.async_success_embedding: bool = False # type: ignore - self.async_failure: bool = False # type: ignore + self.async_failure: bool = False # type: ignore self.async_failure_embedding: bool = False # type: ignore - self.async_completion_kwargs = None # type: ignore - self.async_embedding_kwargs = None # type: ignore - self.async_embedding_response = None # type: ignore + self.async_completion_kwargs = None # type: ignore + self.async_embedding_kwargs = None # type: ignore + self.async_embedding_response = None # type: ignore - self.async_completion_kwargs_fail = None # type: ignore - self.async_embedding_kwargs_fail = None # type: ignore + self.async_completion_kwargs_fail = None # type: ignore + self.async_embedding_kwargs_fail = None # type: ignore - self.stream_collected_response = None # type: ignore - self.sync_stream_collected_response = None # type: ignore - self.user = None # type: ignore + 
self.stream_collected_response = None # type: ignore + self.sync_stream_collected_response = None # type: ignore + self.user = None # type: ignore self.data_sent_to_api: dict = {} - def log_pre_api_call(self, model, messages, kwargs): + def log_pre_api_call(self, model, messages, kwargs): print(f"Pre-API Call") self.data_sent_to_api = kwargs["additional_args"].get("complete_input_dict", {}) - - def log_post_api_call(self, kwargs, response_obj, start_time, end_time): + + def log_post_api_call(self, kwargs, response_obj, start_time, end_time): print(f"Post-API Call") - + def log_stream_event(self, kwargs, response_obj, start_time, end_time): print(f"On Stream") - - def log_success_event(self, kwargs, response_obj, start_time, end_time): + + def log_success_event(self, kwargs, response_obj, start_time, end_time): print(f"On Success") self.success = True if kwargs.get("stream") == True: self.sync_stream_collected_response = response_obj - - def log_failure_event(self, kwargs, response_obj, start_time, end_time): + def log_failure_event(self, kwargs, response_obj, start_time, end_time): print(f"On Failure") self.failure = True - async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): print(f"On Async success") print(f"received kwargs user: {kwargs['user']}") self.async_success = True @@ -62,24 +64,30 @@ class MyCustomHandler(CustomLogger): self.stream_collected_response = response_obj self.async_completion_kwargs = kwargs self.user = kwargs.get("user", None) - - async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): print(f"On Async Failure") self.async_failure = True if kwargs.get("model") == "text-embedding-ada-002": self.async_failure_embedding = True self.async_embedding_kwargs_fail = kwargs - + self.async_completion_kwargs_fail = kwargs + class TmpFunction: complete_streaming_response_in_callback = "" async_success: bool = False + async def async_test_logging_fn(self, kwargs, completion_obj, start_time, end_time): print(f"ON ASYNC LOGGING") self.async_success = True - print(f'kwargs.get("complete_streaming_response"): {kwargs.get("complete_streaming_response")}') - self.complete_streaming_response_in_callback = kwargs.get("complete_streaming_response") + print( + f'kwargs.get("complete_streaming_response"): {kwargs.get("complete_streaming_response")}' + ) + self.complete_streaming_response_in_callback = kwargs.get( + "complete_streaming_response" + ) def test_async_chat_openai_stream(): @@ -88,29 +96,39 @@ def test_async_chat_openai_stream(): # litellm.set_verbose = True litellm.success_callback = [tmp_function.async_test_logging_fn] complete_streaming_response = "" + async def call_gpt(): nonlocal complete_streaming_response - response = await litellm.acompletion(model="gpt-3.5-turbo", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm openai" - }], - stream=True) - async for chunk in response: - complete_streaming_response += chunk["choices"][0]["delta"]["content"] or "" + response = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}], + stream=True, + ) + async for chunk in response: + complete_streaming_response += ( + chunk["choices"][0]["delta"]["content"] or "" + ) print(complete_streaming_response) + asyncio.run(call_gpt()) complete_streaming_response = 
complete_streaming_response.strip("'") - response1 = tmp_function.complete_streaming_response_in_callback["choices"][0]["message"]["content"] + response1 = tmp_function.complete_streaming_response_in_callback["choices"][0][ + "message" + ]["content"] response2 = complete_streaming_response # assert [ord(c) for c in response1] == [ord(c) for c in response2] + print(f"response1: {response1}") + print(f"response2: {response2}") assert response1 == response2 assert tmp_function.async_success == True except Exception as e: print(e) pytest.fail(f"An error occurred - {str(e)}") + + # test_async_chat_openai_stream() + def test_completion_azure_stream_moderation_failure(): try: customHandler = MyCustomHandler() @@ -122,11 +140,11 @@ def test_completion_azure_stream_moderation_failure(): "content": "how do i kill someone", }, ] - try: + try: response = completion( model="azure/chatgpt-v-2", messages=messages, stream=True ) - for chunk in response: + for chunk in response: print(f"chunk: {chunk}") continue except Exception as e: @@ -139,7 +157,7 @@ def test_completion_azure_stream_moderation_failure(): def test_async_custom_handler_stream(): try: - # [PROD Test] - Do not DELETE + # [PROD Test] - Do not DELETE # checks if the model response available in the async + stream callbacks is equal to the received response customHandler2 = MyCustomHandler() litellm.callbacks = [customHandler2] @@ -152,32 +170,37 @@ def test_async_custom_handler_stream(): }, ] complete_streaming_response = "" + async def test_1(): nonlocal complete_streaming_response response = await litellm.acompletion( - model="azure/chatgpt-v-2", - messages=messages, - stream=True + model="azure/chatgpt-v-2", messages=messages, stream=True ) - async for chunk in response: - complete_streaming_response += chunk["choices"][0]["delta"]["content"] or "" + async for chunk in response: + complete_streaming_response += ( + chunk["choices"][0]["delta"]["content"] or "" + ) print(complete_streaming_response) - + asyncio.run(test_1()) response_in_success_handler = customHandler2.stream_collected_response - response_in_success_handler = response_in_success_handler["choices"][0]["message"]["content"] + response_in_success_handler = response_in_success_handler["choices"][0][ + "message" + ]["content"] print("\n\n") print("response_in_success_handler: ", response_in_success_handler) print("complete_streaming_response: ", complete_streaming_response) assert response_in_success_handler == complete_streaming_response except Exception as e: pytest.fail(f"Error occurred: {e}") + + # test_async_custom_handler_stream() def test_azure_completion_stream(): - # [PROD Test] - Do not DELETE + # [PROD Test] - Do not DELETE # test if completion() + sync custom logger get the same complete stream response try: # checks if the model response available in the async + stream callbacks is equal to the received response @@ -194,17 +217,17 @@ def test_azure_completion_stream(): complete_streaming_response = "" response = litellm.completion( - model="azure/chatgpt-v-2", - messages=messages, - stream=True + model="azure/chatgpt-v-2", messages=messages, stream=True ) - for chunk in response: + for chunk in response: complete_streaming_response += chunk["choices"][0]["delta"]["content"] or "" print(complete_streaming_response) - - time.sleep(0.5) # wait 1/2 second before checking callbacks + + time.sleep(0.5) # wait 1/2 second before checking callbacks response_in_success_handler = customHandler2.sync_stream_collected_response - response_in_success_handler = 
response_in_success_handler["choices"][0]["message"]["content"] + response_in_success_handler = response_in_success_handler["choices"][0][ + "message" + ]["content"] print("\n\n") print("response_in_success_handler: ", response_in_success_handler) print("complete_streaming_response: ", complete_streaming_response) @@ -212,24 +235,32 @@ def test_azure_completion_stream(): except Exception as e: pytest.fail(f"Error occurred: {e}") + @pytest.mark.asyncio -async def test_async_custom_handler_completion(): - try: +async def test_async_custom_handler_completion(): + try: customHandler_success = MyCustomHandler() customHandler_failure = MyCustomHandler() # success assert customHandler_success.async_success == False litellm.callbacks = [customHandler_success] response = await litellm.acompletion( - model="gpt-3.5-turbo", - messages=[{ + model="gpt-3.5-turbo", + messages=[ + { "role": "user", "content": "hello from litellm test", - }] - ) + } + ], + ) await asyncio.sleep(1) - assert customHandler_success.async_success == True, "async success is not set to True even after success" - assert customHandler_success.async_completion_kwargs.get("model") == "gpt-3.5-turbo" + assert ( + customHandler_success.async_success == True + ), "async success is not set to True even after success" + assert ( + customHandler_success.async_completion_kwargs.get("model") + == "gpt-3.5-turbo" + ) # failure litellm.callbacks = [customHandler_failure] messages = [ @@ -240,80 +271,119 @@ async def test_async_custom_handler_completion(): }, ] - assert customHandler_failure.async_failure == False - try: + assert customHandler_failure.async_failure == False + try: response = await litellm.acompletion( - model="gpt-3.5-turbo", - messages=messages, - api_key="my-bad-key", - ) + model="gpt-3.5-turbo", + messages=messages, + api_key="my-bad-key", + ) except: pass - assert customHandler_failure.async_failure == True, "async failure is not set to True even after failure" - assert customHandler_failure.async_completion_kwargs_fail.get("model") == "gpt-3.5-turbo" - assert len(str(customHandler_failure.async_completion_kwargs_fail.get("exception"))) > 10 # expect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119 + assert ( + customHandler_failure.async_failure == True + ), "async failure is not set to True even after failure" + assert ( + customHandler_failure.async_completion_kwargs_fail.get("model") + == "gpt-3.5-turbo" + ) + assert ( + len( + str(customHandler_failure.async_completion_kwargs_fail.get("exception")) + ) + > 10 + ) # expect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. 
You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119 litellm.callbacks = [] print("Passed setting async failure") except Exception as e: pytest.fail(f"An exception occurred - {str(e)}") + + # asyncio.run(test_async_custom_handler_completion()) + @pytest.mark.asyncio -async def test_async_custom_handler_embedding(): - try: +async def test_async_custom_handler_embedding(): + try: customHandler_embedding = MyCustomHandler() litellm.callbacks = [customHandler_embedding] # success assert customHandler_embedding.async_success_embedding == False response = await litellm.aembedding( - model="text-embedding-ada-002", - input = ["hello world"], - ) + model="text-embedding-ada-002", + input=["hello world"], + ) await asyncio.sleep(1) - assert customHandler_embedding.async_success_embedding == True, "async_success_embedding is not set to True even after success" - assert customHandler_embedding.async_embedding_kwargs.get("model") == "text-embedding-ada-002" - assert customHandler_embedding.async_embedding_response["usage"]["prompt_tokens"] ==2 + assert ( + customHandler_embedding.async_success_embedding == True + ), "async_success_embedding is not set to True even after success" + assert ( + customHandler_embedding.async_embedding_kwargs.get("model") + == "text-embedding-ada-002" + ) + assert ( + customHandler_embedding.async_embedding_response["usage"]["prompt_tokens"] + == 2 + ) print("Passed setting async success: Embedding") - # failure + # failure assert customHandler_embedding.async_failure_embedding == False - try: + try: response = await litellm.aembedding( - model="text-embedding-ada-002", - input = ["hello world"], - api_key="my-bad-key", - ) - except: + model="text-embedding-ada-002", + input=["hello world"], + api_key="my-bad-key", + ) + except: pass - assert customHandler_embedding.async_failure_embedding == True, "async failure embedding is not set to True even after failure" - assert customHandler_embedding.async_embedding_kwargs_fail.get("model") == "text-embedding-ada-002" - assert len(str(customHandler_embedding.async_embedding_kwargs_fail.get("exception"))) > 10 # exppect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. 
You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119 + assert ( + customHandler_embedding.async_failure_embedding == True + ), "async failure embedding is not set to True even after failure" + assert ( + customHandler_embedding.async_embedding_kwargs_fail.get("model") + == "text-embedding-ada-002" + ) + assert ( + len( + str( + customHandler_embedding.async_embedding_kwargs_fail.get("exception") + ) + ) + > 10 + ) # exppect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119 except Exception as e: pytest.fail(f"An exception occurred - {str(e)}") + + # asyncio.run(test_async_custom_handler_embedding()) + @pytest.mark.asyncio -async def test_async_custom_handler_embedding_optional_param(): +async def test_async_custom_handler_embedding_optional_param(): """ - Tests if the openai optional params for embedding - user + encoding_format, + Tests if the openai optional params for embedding - user + encoding_format, are logged """ customHandler_optional_params = MyCustomHandler() litellm.callbacks = [customHandler_optional_params] response = await litellm.aembedding( - model="azure/azure-embedding-model", - input = ["hello world"], - user = "John" - ) - await asyncio.sleep(1) # success callback is async + model="azure/azure-embedding-model", input=["hello world"], user="John" + ) + await asyncio.sleep(1) # success callback is async assert customHandler_optional_params.user == "John" - assert customHandler_optional_params.user == customHandler_optional_params.data_sent_to_api["user"] + assert ( + customHandler_optional_params.user + == customHandler_optional_params.data_sent_to_api["user"] + ) + # asyncio.run(test_async_custom_handler_embedding_optional_param()) + @pytest.mark.asyncio -async def test_async_custom_handler_embedding_optional_param_bedrock(): +async def test_async_custom_handler_embedding_optional_param_bedrock(): """ - Tests if the openai optional params for embedding - user + encoding_format, + Tests if the openai optional params for embedding - user + encoding_format, are logged but makes sure these are not sent to the non-openai/azure endpoint (raises errors). 
@@ -323,42 +393,68 @@ async def test_async_custom_handler_embedding_optional_param_bedrock(): customHandler_optional_params = MyCustomHandler() litellm.callbacks = [customHandler_optional_params] response = await litellm.aembedding( - model="bedrock/amazon.titan-embed-text-v1", - input = ["hello world"], - user = "John" - ) - await asyncio.sleep(1) # success callback is async + model="bedrock/amazon.titan-embed-text-v1", input=["hello world"], user="John" + ) + await asyncio.sleep(1) # success callback is async assert customHandler_optional_params.user == "John" assert "user" not in customHandler_optional_params.data_sent_to_api def test_redis_cache_completion_stream(): from litellm import Cache - # Important Test - This tests if we can add to streaming cache, when custom callbacks are set + + # Important Test - This tests if we can add to streaming cache, when custom callbacks are set import random + try: print("\nrunning test_redis_cache_completion_stream") litellm.set_verbose = True - random_number = random.randint(1, 100000) # add a random number to ensure it's always adding / reading from cache - messages = [{"role": "user", "content": f"write a one sentence poem about: {random_number}"}] - litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD']) + random_number = random.randint( + 1, 100000 + ) # add a random number to ensure it's always adding / reading from cache + messages = [ + { + "role": "user", + "content": f"write a one sentence poem about: {random_number}", + } + ] + litellm.cache = Cache( + type="redis", + host=os.environ["REDIS_HOST"], + port=os.environ["REDIS_PORT"], + password=os.environ["REDIS_PASSWORD"], + ) print("test for caching, streaming + completion") - response1 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True) + response1 = completion( + model="gpt-3.5-turbo", + messages=messages, + max_tokens=40, + temperature=0.2, + stream=True, + ) response_1_content = "" for chunk in response1: print(chunk) response_1_content += chunk.choices[0].delta.content or "" print(response_1_content) - time.sleep(0.1) # sleep for 0.1 seconds allow set cache to occur - response2 = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=40, temperature=0.2, stream=True) + time.sleep(0.1) # sleep for 0.1 seconds allow set cache to occur + response2 = completion( + model="gpt-3.5-turbo", + messages=messages, + max_tokens=40, + temperature=0.2, + stream=True, + ) response_2_content = "" for chunk in response2: print(chunk) response_2_content += chunk.choices[0].delta.content or "" print("\nresponse 1", response_1_content) print("\nresponse 2", response_2_content) - assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}" + assert ( + response_1_content == response_2_content + ), f"Response 1 != Response 2. 
Same params, Response 1{response_1_content} != Response 2{response_2_content}" litellm.success_callback = [] litellm._async_success_callback = [] litellm.cache = None @@ -366,4 +462,6 @@ def test_redis_cache_completion_stream(): print(e) litellm.success_callback = [] raise e -# test_redis_cache_completion_stream() \ No newline at end of file + + +# test_redis_cache_completion_stream() diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index d1f0ee699..630b41d72 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -10,7 +10,7 @@ sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path import litellm -from litellm import embedding, completion +from litellm import embedding, completion, completion_cost litellm.set_verbose = False @@ -341,8 +341,30 @@ def test_sagemaker_embeddings(): response = litellm.embedding( model="sagemaker/berri-benchmarking-gpt-j-6b-fp16", input=["good morning from litellm", "this is another item"], + input_cost_per_second=0.000420, ) print(f"response: {response}") + cost = completion_cost(completion_response=response) + assert ( + cost > 0.0 and cost < 1.0 + ) # should never be > $1 for a single embedding call + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + +@pytest.mark.asyncio +async def test_sagemaker_aembeddings(): + try: + response = await litellm.aembedding( + model="sagemaker/berri-benchmarking-gpt-j-6b-fp16", + input=["good morning from litellm", "this is another item"], + input_cost_per_second=0.000420, + ) + print(f"response: {response}") + cost = completion_cost(completion_response=response) + assert ( + cost > 0.0 and cost < 1.0 + ) # should never be > $1 for a single embedding call except Exception as e: pytest.fail(f"Error occurred: {e}") diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 484c9d5a5..2447448ff 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -35,6 +35,7 @@ import pytest, logging, asyncio import litellm, asyncio from litellm.proxy.proxy_server import ( new_user, + generate_key_fn, user_api_key_auth, user_update, delete_key_fn, @@ -53,6 +54,7 @@ from litellm.proxy._types import ( DynamoDBArgs, DeleteKeyRequest, UpdateKeyRequest, + GenerateKeyRequest, ) from litellm.proxy.utils import DBClient from starlette.datastructures import URL @@ -597,6 +599,85 @@ def test_generate_and_update_key(prisma_client): print("Got Exception", e) print(e.detail) pytest.fail(f"An exception occurred - {str(e)}") + +def test_key_generate_with_custom_auth(prisma_client): + # custom - generate key function + async def custom_generate_key_fn(data: GenerateKeyRequest) -> dict: + """ + Asynchronous function for generating a key based on the input data. + + Args: + data (GenerateKeyRequest): The input data for key generation. + + Returns: + dict: A dictionary containing the decision and an optional message. + { + "decision": False, + "message": "This violates LiteLLM Proxy Rules. 
No team id provided.", + } + """ + + # decide if a key should be generated or not + print("using custom auth function!") + data_json = data.json() # type: ignore + + # Unpacking variables + team_id = data_json.get("team_id") + duration = data_json.get("duration") + models = data_json.get("models") + aliases = data_json.get("aliases") + config = data_json.get("config") + spend = data_json.get("spend") + user_id = data_json.get("user_id") + max_parallel_requests = data_json.get("max_parallel_requests") + metadata = data_json.get("metadata") + tpm_limit = data_json.get("tpm_limit") + rpm_limit = data_json.get("rpm_limit") + + if team_id is not None and team_id == "litellm-core-infra@gmail.com": + # only team_id="litellm-core-infra@gmail.com" can make keys + return { + "decision": True, + } + else: + print("Failed custom auth") + return { + "decision": False, + "message": "This violates LiteLLM Proxy Rules. No team id provided.", + } + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + setattr( + litellm.proxy.proxy_server, "user_custom_key_generate", custom_generate_key_fn + ) + try: + request = GenerateKeyRequest() + key = await generate_key_fn(request) + pytest.fail(f"Expected an exception. Got {key}") + except Exception as e: + # this should fail + print("Got Exception", e) + print(e.detail) + print("First request failed!. This is expected") + assert ( + "This violates LiteLLM Proxy Rules. No team id provided." + in e.detail + ) + + request_2 = GenerateKeyRequest( + team_id="litellm-core-infra@gmail.com", + ) + + key = await generate_key_fn(request_2) + print(key) + generated_key = key.key + + asyncio.run(test()) + except Exception as e: + print("Got Exception", e) + print(e.detail) + pytest.fail(f"An exception occurred - {str(e)}") def test_call_with_key_over_budget(prisma_client): diff --git a/litellm/tests/test_router_fallbacks.py b/litellm/tests/test_router_fallbacks.py index 29bc0d7bf..5d17d36c9 100644 --- a/litellm/tests/test_router_fallbacks.py +++ b/litellm/tests/test_router_fallbacks.py @@ -716,7 +716,7 @@ def test_usage_based_routing_fallbacks(): # Constants for TPM and RPM allocation AZURE_FAST_TPM = 3 AZURE_BASIC_TPM = 4 - OPENAI_TPM = 2000 + OPENAI_TPM = 400 ANTHROPIC_TPM = 100000 def get_azure_params(deployment_name: str): @@ -775,6 +775,7 @@ def test_usage_based_routing_fallbacks(): model_list=model_list, fallbacks=fallbacks_list, set_verbose=True, + debug_level="DEBUG", routing_strategy="usage-based-routing", redis_host=os.environ["REDIS_HOST"], redis_port=os.environ["REDIS_PORT"], @@ -783,17 +784,32 @@ def test_usage_based_routing_fallbacks(): messages = [ {"content": "Tell me a joke.", "role": "user"}, ] - response = router.completion( - model="azure/gpt-4-fast", messages=messages, timeout=5 + model="azure/gpt-4-fast", + messages=messages, + timeout=5, + mock_response="very nice to meet you", ) print("response: ", response) print("response._hidden_params: ", response._hidden_params) - # in this test, we expect azure/gpt-4 fast to fail, then azure-gpt-4 basic to fail and then openai-gpt-4 to pass # the token count of this message is > AZURE_FAST_TPM, > AZURE_BASIC_TPM assert response._hidden_params["custom_llm_provider"] == "openai" + # now make 100 mock requests to OpenAI - expect it to fallback to anthropic-claude-instant-1.2 + for i in range(20): + response = router.completion( + model="azure/gpt-4-fast", + messages=messages, + timeout=5, + mock_response="very nice to meet you", + ) + 
print("response: ", response) + print("response._hidden_params: ", response._hidden_params) + if i == 19: + # by the 19th call we should have hit TPM LIMIT for OpenAI, it should fallback to anthropic-claude-instant-1.2 + assert response._hidden_params["custom_llm_provider"] == "anthropic" + except Exception as e: pytest.fail(f"An exception occurred {e}") diff --git a/litellm/utils.py b/litellm/utils.py index c690f7cc4..468e67136 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -765,6 +765,7 @@ class Logging: self.litellm_call_id = litellm_call_id self.function_id = function_id self.streaming_chunks = [] # for generating complete stream response + self.sync_streaming_chunks = [] # for generating complete stream response self.model_call_details = {} def update_environment_variables( @@ -828,7 +829,7 @@ class Logging: [f"-H '{k}: {v}'" for k, v in masked_headers.items()] ) - print_verbose(f"PRE-API-CALL ADDITIONAL ARGS: {additional_args}") + verbose_logger.debug(f"PRE-API-CALL ADDITIONAL ARGS: {additional_args}") curl_command = "\n\nPOST Request Sent from LiteLLM:\n" curl_command += "curl -X POST \\\n" @@ -994,13 +995,10 @@ class Logging: self.model_call_details["log_event_type"] = "post_api_call" # User Logging -> if you pass in a custom logging function - print_verbose( + verbose_logger.debug( f"RAW RESPONSE:\n{self.model_call_details.get('original_response', self.model_call_details)}\n\n" ) - print_verbose( - f"Logging Details Post-API Call: logger_fn - {self.logger_fn} | callable(logger_fn) - {callable(self.logger_fn)}" - ) - print_verbose( + verbose_logger.debug( f"Logging Details Post-API Call: LiteLLM Params: {self.model_call_details}" ) if self.logger_fn and callable(self.logger_fn): @@ -1094,20 +1092,20 @@ class Logging: if ( result.choices[0].finish_reason is not None ): # if it's the last chunk - self.streaming_chunks.append(result) - # print_verbose(f"final set of received chunks: {self.streaming_chunks}") + self.sync_streaming_chunks.append(result) + # print_verbose(f"final set of received chunks: {self.sync_streaming_chunks}") try: complete_streaming_response = litellm.stream_chunk_builder( - self.streaming_chunks, + self.sync_streaming_chunks, messages=self.model_call_details.get("messages", None), ) except: complete_streaming_response = None else: - self.streaming_chunks.append(result) + self.sync_streaming_chunks.append(result) if complete_streaming_response: - verbose_logger.info( + verbose_logger.debug( f"Logging Details LiteLLM-Success Call streaming complete" ) self.model_call_details[ @@ -1307,7 +1305,9 @@ class Logging: ) == False ): # custom logger class - print_verbose(f"success callbacks: Running Custom Logger Class") + verbose_logger.info( + f"success callbacks: Running SYNC Custom Logger Class" + ) if self.stream and complete_streaming_response is None: callback.log_stream_event( kwargs=self.model_call_details, @@ -1329,7 +1329,17 @@ class Logging: start_time=start_time, end_time=end_time, ) - if callable(callback): # custom logger functions + elif ( + callable(callback) == True + and self.model_call_details.get("litellm_params", {}).get( + "acompletion", False + ) + == False + and self.model_call_details.get("litellm_params", {}).get( + "aembedding", False + ) + == False + ): # custom logger functions print_verbose( f"success callbacks: Running Custom Callback Function" ) @@ -1364,6 +1374,9 @@ class Logging: Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. 
""" print_verbose(f"Async success callbacks: {litellm._async_success_callback}") + start_time, end_time, result = self._success_handler_helper_fn( + start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit + ) ## BUILD COMPLETE STREAMED RESPONSE complete_streaming_response = None if self.stream: @@ -1374,6 +1387,8 @@ class Logging: complete_streaming_response = litellm.stream_chunk_builder( self.streaming_chunks, messages=self.model_call_details.get("messages", None), + start_time=start_time, + end_time=end_time, ) except Exception as e: print_verbose( @@ -1387,9 +1402,7 @@ class Logging: self.model_call_details[ "complete_streaming_response" ] = complete_streaming_response - start_time, end_time, result = self._success_handler_helper_fn( - start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit - ) + for callback in litellm._async_success_callback: try: if callback == "cache" and litellm.cache is not None: @@ -1436,7 +1449,6 @@ class Logging: end_time=end_time, ) if callable(callback): # custom logger functions - print_verbose(f"Async success callbacks: async_log_event") await customLogger.async_log_event( kwargs=self.model_call_details, response_obj=result, @@ -2134,7 +2146,7 @@ def client(original_function): litellm.cache.add_cache(result, *args, **kwargs) # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated - print_verbose(f"Wrapper: Completed Call, calling success_handler") + verbose_logger.info(f"Wrapper: Completed Call, calling success_handler") threading.Thread( target=logging_obj.success_handler, args=(result, start_time, end_time) ).start() @@ -2373,7 +2385,9 @@ def client(original_function): result._hidden_params["model_id"] = kwargs.get("model_info", {}).get( "id", None ) - if isinstance(result, ModelResponse): + if isinstance(result, ModelResponse) or isinstance( + result, EmbeddingResponse + ): result._response_ms = ( end_time - start_time ).total_seconds() * 1000 # return response latency in ms like openai @@ -2806,7 +2820,11 @@ def token_counter( def cost_per_token( - model="", prompt_tokens=0, completion_tokens=0, custom_llm_provider=None + model="", + prompt_tokens=0, + completion_tokens=0, + response_time_ms=None, + custom_llm_provider=None, ): """ Calculates the cost per token for a given model, prompt tokens, and completion tokens. 
@@ -2828,14 +2846,35 @@ def cost_per_token( else: model_with_provider = model # see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models - print_verbose(f"Looking up model={model} in model_cost_map") + verbose_logger.debug(f"Looking up model={model} in model_cost_map") if model in model_cost_ref: - prompt_tokens_cost_usd_dollar = ( - model_cost_ref[model]["input_cost_per_token"] * prompt_tokens - ) - completion_tokens_cost_usd_dollar = ( - model_cost_ref[model]["output_cost_per_token"] * completion_tokens + verbose_logger.debug(f"Success: model={model} in model_cost_map") + if ( + model_cost_ref[model].get("input_cost_per_token", None) is not None + and model_cost_ref[model].get("output_cost_per_token", None) is not None + ): + ## COST PER TOKEN ## + prompt_tokens_cost_usd_dollar = ( + model_cost_ref[model]["input_cost_per_token"] * prompt_tokens + ) + completion_tokens_cost_usd_dollar = ( + model_cost_ref[model]["output_cost_per_token"] * completion_tokens + ) + elif ( + model_cost_ref[model].get("input_cost_per_second", None) is not None + and response_time_ms is not None + ): + verbose_logger.debug( + f"For model={model} - input_cost_per_second: {model_cost_ref[model].get('input_cost_per_second')}; response time: {response_time_ms}" + ) + ## COST PER SECOND ## + prompt_tokens_cost_usd_dollar = ( + model_cost_ref[model]["input_cost_per_second"] * response_time_ms / 1000 + ) + completion_tokens_cost_usd_dollar = 0.0 + verbose_logger.debug( + f"Returned custom cost for model={model} - prompt_tokens_cost_usd_dollar: {prompt_tokens_cost_usd_dollar}, completion_tokens_cost_usd_dollar: {completion_tokens_cost_usd_dollar}" ) return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar elif model_with_provider in model_cost_ref: @@ -2938,6 +2977,10 @@ def completion_cost( completion_tokens = completion_response.get("usage", {}).get( "completion_tokens", 0 ) + total_time = completion_response.get("_response_ms", 0) + verbose_logger.debug( + f"completion_response response ms: {completion_response.get('_response_ms')} " + ) model = ( model or completion_response["model"] ) # check if user passed an override for model, if it's none check completion_response['model'] @@ -2975,6 +3018,7 @@ def completion_cost( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, custom_llm_provider=custom_llm_provider, + response_time_ms=total_time, ) return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar except Exception as e: @@ -3005,9 +3049,8 @@ def register_model(model_cost: Union[str, dict]): for key, value in loaded_model_cost.items(): ## override / add new keys to the existing model cost dictionary - if key in litellm.model_cost: - for k, v in loaded_model_cost[key].items(): - litellm.model_cost[key][k] = v + litellm.model_cost.setdefault(key, {}).update(value) + verbose_logger.debug(f"{key} added to model cost map") # add new model names to provider lists if value.get("litellm_provider") == "openai": if key not in litellm.open_ai_chat_completion_models: @@ -3300,11 +3343,13 @@ def get_optional_params( ) def _check_valid_arg(supported_params): - print_verbose( + verbose_logger.debug( f"\nLiteLLM completion() model= {model}; provider = {custom_llm_provider}" ) - print_verbose(f"\nLiteLLM: Params passed to completion() {passed_params}") - print_verbose( + verbose_logger.debug( + f"\nLiteLLM: Params passed to completion() {passed_params}" + ) + verbose_logger.debug( f"\nLiteLLM: Non-Default params passed to completion() {non_default_params}" ) 
unsupported_params = {} @@ -5150,6 +5195,8 @@ def convert_to_model_response_object( "completion", "embedding", "image_generation" ] = "completion", stream=False, + start_time=None, + end_time=None, ): try: if response_type == "completion" and ( @@ -5203,6 +5250,12 @@ def convert_to_model_response_object( if "model" in response_object: model_response_object.model = response_object["model"] + + if start_time is not None and end_time is not None: + model_response_object._response_ms = ( # type: ignore + end_time - start_time + ).total_seconds() * 1000 + return model_response_object elif response_type == "embedding" and ( model_response_object is None @@ -5227,6 +5280,11 @@ def convert_to_model_response_object( model_response_object.usage.prompt_tokens = response_object["usage"].get("prompt_tokens", 0) # type: ignore model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore + if start_time is not None and end_time is not None: + model_response_object._response_ms = ( # type: ignore + end_time - start_time + ).total_seconds() * 1000 # return response latency in ms like openai + return model_response_object elif response_type == "image_generation" and ( model_response_object is None diff --git a/pyproject.toml b/pyproject.toml index d8638bd4b..de6107b67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.18.8" +version = "1.18.9" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT" @@ -61,7 +61,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.18.8" +version = "1.18.9" version_files = [ "pyproject.toml:^version" ]
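
As a quick sanity check of the cost-per-second pricing path introduced above, the same flow can be exercised straight from the SDK, mirroring the updated `test_completion_sagemaker` test. This is a minimal sketch, assuming AWS credentials for the Sagemaker endpoint are already configured in the environment and reusing the benchmarking endpoint name from the tests:

```python
from litellm import completion, completion_cost

# Passing input_cost_per_second makes completion() register the model in
# litellm.model_cost (see the register_model call added in litellm/main.py).
response = completion(
    model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    max_tokens=80,
    input_cost_per_second=0.000420,  # $ per second of response time
)

# cost_per_token() now bills input_cost_per_second * (_response_ms / 1000)
# when no per-token pricing is registered for the model.
cost = completion_cost(completion_response=response)
print(f"cost for this call: ${cost:.6f}")
```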