diff --git a/docs/my-website/docs/proxy/custom_pricing.md b/docs/my-website/docs/proxy/custom_pricing.md
index 10ae066678..8eeaa77ef4 100644
--- a/docs/my-website/docs/proxy/custom_pricing.md
+++ b/docs/my-website/docs/proxy/custom_pricing.md
@@ -2,14 +2,48 @@ import Image from '@theme/IdealImage';
 # Custom Pricing - Sagemaker, etc. 
-Use this to register custom pricing (cost per token or cost per second) for models.
+Use this to register custom pricing for models.
+
+There are two ways to track cost:
+- cost per token
+- cost per second
+
+By default, the response cost is accessible in the logging object via `kwargs["response_cost"]` on success (sync + async). [**Learn More**](../observability/custom_callback.md)
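+
+For example, a custom callback can read the cost off `kwargs` (a minimal sketch, assuming a function-style success callback):
+
+```python
+import litellm
+
+def track_cost_callback(kwargs, completion_response, start_time, end_time):
+    # set by litellm on success; None if the model isn't in the cost map and no custom pricing is registered
+    response_cost = kwargs.get("response_cost")
+    print(f"response_cost: {response_cost}")
+
+litellm.success_callback = [track_cost_callback]
+```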
 ## Quick Start 
-Register custom pricing for sagemaker completion + embedding models. 
+Register custom pricing for a Sagemaker completion model. 
+
+For cost per second pricing, you **just** need to register `input_cost_per_second`. 
+
+```python
+# !pip install boto3 
+import os
+from litellm import completion, completion_cost 
+
+os.environ["AWS_ACCESS_KEY_ID"] = ""
+os.environ["AWS_SECRET_ACCESS_KEY"] = ""
+os.environ["AWS_REGION_NAME"] = ""
+
+
+def test_completion_sagemaker():
+    try:
+        print("testing sagemaker")
+        response = completion(
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            input_cost_per_second=0.000420,
+        )
+        # Add any assertions here to check the response
+        print(response)
+        cost = completion_cost(completion_response=response)
+        print(cost)
+    except Exception as e:
+        raise Exception(f"Error occurred: {e}")
+
+```
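+
+As a rough sketch of the math (assuming cost per second pricing is applied to the call duration): a request that takes 5 seconds at `input_cost_per_second=0.000420` works out to roughly `5 * 0.000420 = 0.0021`.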
+
+### Usage with OpenAI Proxy Server 
+
 **Step 1: Add pricing to config.yaml**
 ```yaml
 model_list:
@@ -31,4 +65,44 @@ litellm /path/to/config.yaml
 **Step 3: View Spend Logs** 
-
\ No newline at end of file
+
+
+## Cost Per Token 
+
+```python
+# !pip install boto3 
+import os
+from litellm import completion, completion_cost 
+
+os.environ["AWS_ACCESS_KEY_ID"] = ""
+os.environ["AWS_SECRET_ACCESS_KEY"] = ""
+os.environ["AWS_REGION_NAME"] = ""
+
+
+def test_completion_sagemaker():
+    try:
+        print("testing sagemaker")
+        response = completion(
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            input_cost_per_token=0.005,
+            output_cost_per_token=1,
+        )
+        # Add any assertions here to check the response
+        print(response)
+        cost = completion_cost(completion_response=response)
+        print(cost)
+    except Exception as e:
+        raise Exception(f"Error occurred: {e}")
+
+```
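+
+For example, with `input_cost_per_token=0.005` and `output_cost_per_token=1`, a response with 10 prompt tokens and 20 completion tokens costs `10 * 0.005 + 20 * 1 = 20.05`.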
+
+### Usage with OpenAI Proxy Server 
+
+```yaml
+model_list:
+  - model_name: sagemaker-completion-model
+    litellm_params:
+      model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
+      input_cost_per_token: 0.000420 # 👈 key change
+      output_cost_per_token: 0.000420 # 👈 key change
+```
\ No newline at end of file
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index 8e20426fde..6a5033e98b 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -139,13 +139,13 @@ const sidebars = {
           "items": [
             "proxy/call_hooks",
             "proxy/rules",
-            "proxy/custom_pricing"
           ]
         },
         "proxy/deploy",
         "proxy/cli",
       ]
     },
+    "proxy/custom_pricing",
     "routing",
     "rules",
     "set_keys",
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 9285bf6f55..0bf201284a 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -706,15 +706,16 @@ class OpenAIChatCompletion(BaseLLM):
             ## COMPLETION CALL
             response = openai_client.images.generate(**data, timeout=timeout)  # type: ignore
+            response = response.model_dump()  # type: ignore
             ## LOGGING
             logging_obj.post_call(
-                input=input,
+                input=prompt,
                 api_key=api_key,
                 additional_args={"complete_input_dict": data},
                 original_response=response,
             )
             # return response
-            return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="image_generation")  # type: ignore
+            return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation")  # type: ignore
         except OpenAIError as e:
             exception_mapping_worked = True
             raise e
diff --git a/litellm/main.py b/litellm/main.py
index 2d8f2c0c91..9c09085b13 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -83,6 +83,7 @@ from litellm.utils import (
     TextCompletionResponse,
     TextChoices,
     EmbeddingResponse,
+    ImageResponse,
     read_config_args,
     Choices,
     Message,
@@ -2987,6 +2988,7 @@ def image_generation(
     else:
         model = "dall-e-2"
         custom_llm_provider = "openai"  # default to dall-e-2 on openai
+    model_response._hidden_params["model"] = model
     openai_params = [
         "user",
         "request_timeout",
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 24caf5b94f..78e756a2a6 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -570,13 +570,8 @@ async def track_cost_callback(
         litellm_params = kwargs.get("litellm_params", {}) or {}
         proxy_server_request = litellm_params.get("proxy_server_request") or {}
         user_id = proxy_server_request.get("body", {}).get("user", None)
-        if "complete_streaming_response" in kwargs:
-            # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
-            completion_response = kwargs["complete_streaming_response"]
-            response_cost = litellm.completion_cost(
-                completion_response=completion_response
-            )
-
+        if "response_cost" in kwargs:
+            response_cost = kwargs["response_cost"]
             user_api_key = kwargs["litellm_params"]["metadata"].get(
                 "user_api_key", None
             )
@@ -585,31 +580,6 @@ async def track_cost_callback(
                 "user_api_key_user_id", None
             )
-            verbose_proxy_logger.info(
-                f"streaming response_cost {response_cost}, for user_id {user_id}"
-            )
-            if user_api_key and (
-                prisma_client is not None or custom_db_client is not None
-            ):
-                await update_database(
-                    token=user_api_key,
-                    response_cost=response_cost,
-                    user_id=user_id,
-                    kwargs=kwargs,
-                    completion_response=completion_response,
-                    start_time=start_time,
-                    end_time=end_time,
-                )
-        elif kwargs["stream"] == False:  # for non streaming responses
-            response_cost = litellm.completion_cost(
-                completion_response=completion_response
-            )
-            user_api_key = kwargs["litellm_params"]["metadata"].get(
-                "user_api_key", None
-            )
-            user_id = user_id or kwargs["litellm_params"]["metadata"].get(
-                "user_api_key_user_id", None
-            )
             verbose_proxy_logger.info(
                 f"response_cost {response_cost}, for user_id {user_id}"
             )
@@ -625,6 +595,10 @@ async def track_cost_callback(
                 start_time=start_time,
                 end_time=end_time,
             )
+        else:
+            raise Exception(
+                f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
+            )
     except Exception as e:
         verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py
index 0fb69b6451..556628d828 100644
--- a/litellm/tests/test_custom_callback_input.py
+++ b/litellm/tests/test_custom_callback_input.py
@@ -74,6 +74,7 @@ class CompletionCustomHandler(
     def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
         try:
+            print(f"kwargs: {kwargs}")
             self.states.append("post_api_call")
             ## START TIME
             assert isinstance(start_time, datetime)
@@ -149,7 +150,14 @@ class CompletionCustomHandler(
             ## END TIME
             assert isinstance(end_time, datetime)
             ## RESPONSE OBJECT
-            assert isinstance(response_obj, litellm.ModelResponse)
+            assert isinstance(
+                response_obj,
+                (
+                    litellm.ModelResponse,
+                    litellm.EmbeddingResponse,
+                    litellm.ImageResponse,
+                ),
+            )
             ## KWARGS
             assert isinstance(kwargs["model"], str)
             assert isinstance(kwargs["messages"], list) and isinstance(
@@ -170,12 +178,14 @@ class CompletionCustomHandler(
             )
             assert isinstance(kwargs["additional_args"], (dict, type(None)))
             assert isinstance(kwargs["log_event_type"], str)
+            assert isinstance(kwargs["response_cost"], (float, type(None)))
         except:
             print(f"Assertion Error: {traceback.format_exc()}")
             self.errors.append(traceback.format_exc())

     def log_failure_event(self, kwargs, response_obj, start_time, end_time):
         try:
+            print(f"kwargs: {kwargs}")
             self.states.append("sync_failure")
             ## START TIME
             assert isinstance(start_time, datetime)
@@ -262,6 +272,7 @@ class CompletionCustomHandler(
             assert isinstance(kwargs["additional_args"], (dict, type(None)))
             assert isinstance(kwargs["log_event_type"], str)
             assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
+            assert isinstance(kwargs["response_cost"], (float, type(None)))
         except:
             print(f"Assertion Error: {traceback.format_exc()}")
             self.errors.append(traceback.format_exc())
@@ -545,8 +556,9 @@ async def test_async_chat_bedrock_stream():
 # asyncio.run(test_async_chat_bedrock_stream())
-# Text Completion
-
+# Text Completion 
+
+
 ## Test OpenAI text completion + Async
 @pytest.mark.asyncio
 async def test_async_text_completion_openai_stream():
@@ -585,6 +597,7 @@ async def test_async_text_completion_openai_stream():
     except Exception as e:
         pytest.fail(f"An exception occurred: {str(e)}")
+
 # EMBEDDING
 ## Test OpenAI + Async
 @pytest.mark.asyncio
@@ -762,6 +775,52 @@ async def test_async_embedding_azure_caching():
     assert len(customHandler_caching.states) == 4  # pre, post, success, success
-# asyncio.run(
-#     test_async_embedding_azure_caching()
-# )
+# Image Generation
+
+
+# ## Test OpenAI + Sync
+# def test_image_generation_openai():
+#     try:
+#         customHandler_success = CompletionCustomHandler()
+#         customHandler_failure = CompletionCustomHandler()
+#         litellm.callbacks = [customHandler_success]
+
+#         litellm.set_verbose = True
+
+#         response = litellm.image_generation(
+#             prompt="A cute baby sea otter", model="dall-e-3"
+#         )
+
+#         print(f"response: {response}")
+#         assert len(response.data) > 0
+
+#         print(f"customHandler_success.errors: {customHandler_success.errors}")
+#         print(f"customHandler_success.states: {customHandler_success.states}")
+#         assert len(customHandler_success.errors) == 0
+#         assert len(customHandler_success.states) == 3  # pre, post, success
+#         # test failure callback
+#         litellm.callbacks = [customHandler_failure]
+#         try:
+#             response = litellm.image_generation(
+#                 prompt="A cute baby sea otter", model="dall-e-4"
+#             )
+#         except:
+#             pass
+#         print(f"customHandler_failure.errors: {customHandler_failure.errors}")
+#         print(f"customHandler_failure.states: {customHandler_failure.states}")
+#         assert len(customHandler_failure.errors) == 0
+#         assert len(customHandler_failure.states) == 3  # pre, post, failure
+#     except litellm.RateLimitError as e:
+#         pass
+#     except litellm.ContentPolicyViolationError:
+#         pass  # OpenAI randomly raises these errors - skip when they occur
+#     except Exception as e:
+#         pytest.fail(f"An exception occurred - {str(e)}")
+
+
+# test_image_generation_openai()
+## Test OpenAI + Async
+
+## Test Azure + Sync
+
+## Test Azure + Async
diff --git a/litellm/tests/test_key_generate_dynamodb.py b/litellm/tests/test_key_generate_dynamodb.py
index a0772f87a3..3c706b663b 100644
--- a/litellm/tests/test_key_generate_dynamodb.py
+++ b/litellm/tests/test_key_generate_dynamodb.py
@@ -184,9 +184,11 @@ def test_call_with_user_over_budget(custom_db_client):
     # 5. Make a call with a key over budget, expect to fail
     setattr(litellm.proxy.proxy_server, "custom_db_client", custom_db_client)
     setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
-    from litellm._logging import verbose_proxy_logger
+    from litellm._logging import verbose_proxy_logger, verbose_logger
     import logging

+    litellm.set_verbose = True
+    verbose_logger.setLevel(logging.DEBUG)
     verbose_proxy_logger.setLevel(logging.DEBUG)

     try:
@@ -234,6 +236,7 @@ def test_call_with_user_over_budget(custom_db_client):
                         "user_api_key_user_id": user_id,
                     }
                 },
+                "response_cost": 0.00002,
             },
             completion_response=resp,
         )
@@ -306,6 +309,7 @@ def test_call_with_user_over_budget_stream(custom_db_client):
                         "user_api_key_user_id": user_id,
                     }
                 },
+                "response_cost": 0.00002,
             },
             completion_response=ModelResponse(),
         )
diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py
index 4734e10302..7b19449508 100644
--- a/litellm/tests/test_key_generate_prisma.py
+++ b/litellm/tests/test_key_generate_prisma.py
@@ -260,6 +260,7 @@ def test_call_with_user_over_budget(prisma_client):
                         "user_api_key_user_id": user_id,
                     }
                 },
+                "response_cost": 0.00002,
             },
             completion_response=resp,
             start_time=datetime.now(),
@@ -335,6 +336,7 @@ def test_call_with_user_over_budget_stream(prisma_client):
                         "user_api_key_user_id": user_id,
                     }
                 },
+                "response_cost": 0.00002,
             },
             completion_response=ModelResponse(),
             start_time=datetime.now(),
diff --git a/litellm/utils.py b/litellm/utils.py
index afa2830c93..8ee6691b66 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1064,6 +1064,27 @@ class Logging:
             self.model_call_details["log_event_type"] = "successful_api_call"
             self.model_call_details["end_time"] = end_time
             self.model_call_details["cache_hit"] = cache_hit
+            ## if model in model cost map - log the response cost
+            ## else set cost to None
+            verbose_logger.debug(f"Model={self.model}; result={result}")
+            if result is not None and (
+                isinstance(result, ModelResponse)
+                or isinstance(result, EmbeddingResponse)
+            ):
+                try:
+                    self.model_call_details["response_cost"] = litellm.completion_cost(
+                        completion_response=result,
+                    )
+                    verbose_logger.debug(
+                        f"Model={self.model}; cost={self.model_call_details['response_cost']}"
+                    )
+                except litellm.NotFoundError as e:
+                    verbose_logger.debug(
+                        f"Model={self.model} not found in completion cost map."
+                    )
+                    self.model_call_details["response_cost"] = None
+            else:  # streaming chunks + image gen.
+                self.model_call_details["response_cost"] = None

             if litellm.max_budget and self.stream:
                 time_diff = (end_time - start_time).total_seconds()
@@ -1077,7 +1098,7 @@ class Logging:
             return start_time, end_time, result
         except Exception as e:
-            print_verbose(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}")
+            raise Exception(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}")
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index ac3637f275..7e5f669909 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -906,6 +906,14 @@
         "litellm_provider": "bedrock",
         "mode": "chat"
     },
+    "amazon.titan-embed-text-v1": {
+        "max_tokens": 8192,
+        "output_vector_size": 1536,
+        "input_cost_per_token": 0.0000001,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "bedrock",
+        "mode": "embedding"
+    },
     "anthropic.claude-v1": {
         "max_tokens": 100000,
         "max_output_tokens": 8191,