diff --git a/docs/my-website/docs/proxy/custom_pricing.md b/docs/my-website/docs/proxy/custom_pricing.md
index 10ae06667..8eeaa77ef 100644
--- a/docs/my-website/docs/proxy/custom_pricing.md
+++ b/docs/my-website/docs/proxy/custom_pricing.md
@@ -2,14 +2,48 @@ import Image from '@theme/IdealImage';
# Custom Pricing - Sagemaker, etc.
-Use this to register custom pricing (cost per token or cost per second) for models.
+Use this to register custom pricing for models.
+
+There are two ways to track cost:
+- cost per token
+- cost per second
+
+By default, the response cost is accessible in the logging object via `kwargs["response_cost"]` on success (sync + async). [**Learn More**](../observability/custom_callback.md)
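+
+For example, a custom success callback can read the computed cost straight from `kwargs` (a minimal sketch, assuming the callback interface described in the linked guide):
+
+```python
+import litellm
+
+
+def track_cost(kwargs, completion_response, start_time, end_time):
+    # `response_cost` is set by litellm on success; it is None if the model has no pricing registered
+    print(f"response_cost: {kwargs.get('response_cost')}")
+
+
+litellm.success_callback = [track_cost]
+```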
## Quick Start
-Register custom pricing for sagemaker completion + embedding models.
+Register custom pricing for a sagemaker completion model.
For cost per second pricing, you **just** need to register `input_cost_per_second`.
+```python
+# !pip install boto3
+import os
+from litellm import completion, completion_cost
+
+os.environ["AWS_ACCESS_KEY_ID"] = ""
+os.environ["AWS_SECRET_ACCESS_KEY"] = ""
+os.environ["AWS_REGION_NAME"] = ""
+
+
+def test_completion_sagemaker():
+ try:
+ print("testing sagemaker")
+ response = completion(
+ model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+ messages=[{"role": "user", "content": "Hey, how's it going?"}],
+ input_cost_per_second=0.000420,
+ )
+ # Add any assertions here to check the response
+ print(response)
+ cost = completion_cost(completion_response=response)
+ print(cost)
+ except Exception as e:
+ raise Exception(f"Error occurred: {e}")
+
+```
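+
+Roughly, cost-per-second pricing multiplies the measured response time by `input_cost_per_second`. A quick back-of-the-envelope check (a sketch, using an assumed 5.2 second response):
+
+```python
+input_cost_per_second = 0.000420
+response_time_seconds = 5.2  # assumed response time, for illustration only
+
+expected_cost = input_cost_per_second * response_time_seconds
+print(expected_cost)  # 0.002184
+```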
+
+### Usage with OpenAI Proxy Server
+
**Step 1: Add pricing to config.yaml**
```yaml
model_list:
@@ -31,4 +65,44 @@ litellm /path/to/config.yaml
**Step 3: View Spend Logs**
-
\ No newline at end of file
+
+
+## Cost Per Token
+
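+Pass `input_cost_per_token` and `output_cost_per_token` directly on the completion call:
+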
+```python
+# !pip install boto3
+import os
+from litellm import completion, completion_cost
+
+os.environ["AWS_ACCESS_KEY_ID"] = ""
+os.environ["AWS_SECRET_ACCESS_KEY"] = ""
+os.environ["AWS_REGION_NAME"] = ""
+
+
+def test_completion_sagemaker():
+ try:
+ print("testing sagemaker")
+ response = completion(
+ model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+ messages=[{"role": "user", "content": "Hey, how's it going?"}],
+ input_cost_per_token=0.005,
+ output_cost_per_token=1,
+ )
+ # Add any assertions here to check the response
+ print(response)
+ cost = completion_cost(completion_response=response)
+ print(cost)
+ except Exception as e:
+ raise Exception(f"Error occurred: {e}")
+
+```
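+
+The per-token cost works out to the prompt token count times `input_cost_per_token` plus the completion token count times `output_cost_per_token`. A quick sanity check (a sketch, with made-up token counts):
+
+```python
+# hypothetical usage numbers, for illustration only
+prompt_tokens = 25
+completion_tokens = 10
+
+input_cost_per_token = 0.005
+output_cost_per_token = 1
+
+expected_cost = (
+    prompt_tokens * input_cost_per_token
+    + completion_tokens * output_cost_per_token
+)
+print(expected_cost)  # 0.125 + 10 = 10.125
+```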
+
+### Usage with OpenAI Proxy Server
+
+```yaml
+model_list:
+ - model_name: sagemaker-completion-model
+ litellm_params:
+ model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
+ input_cost_per_token: 0.000420 # 👈 key change
+ output_cost_per_token: 0.000420 # 👈 key change
+```
\ No newline at end of file
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index 8e20426fd..6a5033e98 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -139,13 +139,13 @@ const sidebars = {
"items": [
"proxy/call_hooks",
"proxy/rules",
- "proxy/custom_pricing"
]
},
"proxy/deploy",
"proxy/cli",
]
},
+ "proxy/custom_pricing",
"routing",
"rules",
"set_keys",
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 9285bf6f5..0bf201284 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -706,15 +706,16 @@ class OpenAIChatCompletion(BaseLLM):
## COMPLETION CALL
response = openai_client.images.generate(**data, timeout=timeout) # type: ignore
+ response = response.model_dump() # type: ignore
## LOGGING
logging_obj.post_call(
- input=input,
+ input=prompt,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
# return response
- return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="image_generation") # type: ignore
+ return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore
except OpenAIError as e:
exception_mapping_worked = True
raise e
diff --git a/litellm/main.py b/litellm/main.py
index 2d8f2c0c9..9c09085b1 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -83,6 +83,7 @@ from litellm.utils import (
TextCompletionResponse,
TextChoices,
EmbeddingResponse,
+ ImageResponse,
read_config_args,
Choices,
Message,
@@ -2987,6 +2988,7 @@ def image_generation(
else:
model = "dall-e-2"
custom_llm_provider = "openai" # default to dall-e-2 on openai
+ model_response._hidden_params["model"] = model
openai_params = [
"user",
"request_timeout",
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 8ae3ee7d3..dfb4b70f8 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -587,6 +587,10 @@ async def track_cost_callback(
start_time=start_time,
end_time=end_time,
)
+ else:
+ raise Exception(
+ f"Model={kwargs['model']} not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
+ )
except Exception as e:
verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py
index e4bd9e2c1..532fac5d7 100644
--- a/litellm/tests/test_custom_callback_input.py
+++ b/litellm/tests/test_custom_callback_input.py
@@ -74,6 +74,7 @@ class CompletionCustomHandler(
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
try:
+ print(f"kwargs: {kwargs}")
self.states.append("post_api_call")
## START TIME
assert isinstance(start_time, datetime)
@@ -149,7 +150,14 @@ class CompletionCustomHandler(
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
- assert isinstance(response_obj, litellm.ModelResponse)
+ assert isinstance(
+ response_obj,
+ Union[
+ litellm.ModelResponse,
+ litellm.EmbeddingResponse,
+ litellm.ImageResponse,
+ ],
+ )
## KWARGS
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list) and isinstance(
@@ -177,6 +185,7 @@ class CompletionCustomHandler(
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
+ print(f"kwargs: {kwargs}")
self.states.append("sync_failure")
## START TIME
assert isinstance(start_time, datetime)
@@ -766,6 +775,52 @@ async def test_async_embedding_azure_caching():
assert len(customHandler_caching.states) == 4 # pre, post, success, success
-# asyncio.run(
-# test_async_embedding_azure_caching()
-# )
+# Image Generation
+
+
+# ## Test OpenAI + Sync
+# def test_image_generation_openai():
+# try:
+# customHandler_success = CompletionCustomHandler()
+# customHandler_failure = CompletionCustomHandler()
+# litellm.callbacks = [customHandler_success]
+
+# litellm.set_verbose = True
+
+# response = litellm.image_generation(
+# prompt="A cute baby sea otter", model="dall-e-3"
+# )
+
+# print(f"response: {response}")
+# assert len(response.data) > 0
+
+# print(f"customHandler_success.errors: {customHandler_success.errors}")
+# print(f"customHandler_success.states: {customHandler_success.states}")
+# assert len(customHandler_success.errors) == 0
+# assert len(customHandler_success.states) == 3 # pre, post, success
+# # test failure callback
+# litellm.callbacks = [customHandler_failure]
+# try:
+# response = litellm.image_generation(
+# prompt="A cute baby sea otter", model="dall-e-4"
+# )
+# except:
+# pass
+# print(f"customHandler_failure.errors: {customHandler_failure.errors}")
+# print(f"customHandler_failure.states: {customHandler_failure.states}")
+# assert len(customHandler_failure.errors) == 0
+# assert len(customHandler_failure.states) == 3 # pre, post, failure
+# except litellm.RateLimitError as e:
+# pass
+# except litellm.ContentPolicyViolationError:
+# pass # OpenAI randomly raises these errors - skip when they occur
+# except Exception as e:
+# pytest.fail(f"An exception occurred - {str(e)}")
+
+
+# test_image_generation_openai()
+## Test OpenAI + Async
+
+## Test Azure + Sync
+
+## Test Azure + Async
diff --git a/litellm/utils.py b/litellm/utils.py
index 28b514d40..c67b6d39f 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1064,13 +1064,21 @@ class Logging:
self.model_call_details["log_event_type"] = "successful_api_call"
self.model_call_details["end_time"] = end_time
self.model_call_details["cache_hit"] = cache_hit
- if result is not None and (
- isinstance(result, ModelResponse)
- or isinstance(result, EmbeddingResponse)
+ ## if model in model cost map - log the response cost
+ ## else set cost to None
+ if (
+ result is not None
+ and (
+ isinstance(result, ModelResponse)
+ or isinstance(result, EmbeddingResponse)
+ )
+ and result.model in litellm.model_cost
):
self.model_call_details["response_cost"] = litellm.completion_cost(
completion_response=result,
)
+ else: # streaming chunks + image gen.
+ self.model_call_details["response_cost"] = None
if litellm.max_budget and self.stream:
time_diff = (end_time - start_time).total_seconds()
@@ -1084,7 +1092,7 @@ class Logging:
return start_time, end_time, result
except Exception as e:
- print_verbose(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}")
+ raise Exception(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}")
def success_handler(
self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index ac3637f27..7e5f66990 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -906,6 +906,14 @@
"litellm_provider": "bedrock",
"mode": "chat"
},
+ "amazon.titan-embed-text-v1": {
+ "max_tokens": 8192,
+ "output_vector_size": 1536,
+ "input_cost_per_token": 0.0000001,
+ "output_cost_per_token": 0.0,
+ "litellm_provider": "bedrock",
+ "mode": "embedding"
+ },
"anthropic.claude-v1": {
"max_tokens": 100000,
"max_output_tokens": 8191,