fix(openai.py): fix linting issue

commit 3e8c8ef507
parent e917d0eee6
Author: Krrish Dholakia
Date: 2024-01-22 18:20:15 -08:00
8 changed files with 166 additions and 14 deletions

View file

@@ -2,14 +2,48 @@ import Image from '@theme/IdealImage';
# Custom Pricing - Sagemaker, etc.

-Use this to register custom pricing (cost per token or cost per second) for models.
+Use this to register custom pricing for models.
+
+There are two ways to track cost:
+- cost per token
+- cost per second
+
+By default, the response cost is accessible in the logging object via `kwargs["response_cost"]` on success (sync + async). [**Learn More**](../observability/custom_callback.md)

## Quick Start

-Register custom pricing for sagemaker completion + embedding models.
+Register custom pricing for a sagemaker completion model.

For cost per second pricing, you **just** need to register `input_cost_per_second`.

+```python
+# !pip install boto3
+import os
+
+from litellm import completion, completion_cost
+
+os.environ["AWS_ACCESS_KEY_ID"] = ""
+os.environ["AWS_SECRET_ACCESS_KEY"] = ""
+os.environ["AWS_REGION_NAME"] = ""
+
+
+def test_completion_sagemaker():
+    try:
+        print("testing sagemaker")
+        response = completion(
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            input_cost_per_second=0.000420,
+        )
+        # Add any assertions here to check the response
+        print(response)
+        cost = completion_cost(completion_response=response)
+        print(cost)
+    except Exception as e:
+        raise Exception(f"Error occurred: {e}")
+```
+
+### Usage with OpenAI Proxy Server
+
**Step 1: Add pricing to config.yaml**
```yaml
model_list:
@@ -32,3 +66,43 @@ litellm /path/to/config.yaml
**Step 3: View Spend Logs**

<Image img={require('../../img/spend_logs_table.png')} />
+
+## Cost Per Token
+
+```python
+# !pip install boto3
+import os
+
+from litellm import completion, completion_cost
+
+os.environ["AWS_ACCESS_KEY_ID"] = ""
+os.environ["AWS_SECRET_ACCESS_KEY"] = ""
+os.environ["AWS_REGION_NAME"] = ""
+
+
+def test_completion_sagemaker():
+    try:
+        print("testing sagemaker")
+        response = completion(
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            input_cost_per_token=0.005,
+            output_cost_per_token=1,
+        )
+        # Add any assertions here to check the response
+        print(response)
+        cost = completion_cost(completion_response=response)
+        print(cost)
+    except Exception as e:
+        raise Exception(f"Error occurred: {e}")
+```
+
+### Usage with OpenAI Proxy Server
+
+```yaml
+model_list:
+  - model_name: sagemaker-completion-model
+    litellm_params:
+      model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
+      input_cost_per_token: 0.000420 # 👈 key change
+      output_cost_per_token: 0.000420 # 👈 key change
+```
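For readers following the proxy steps above, a request against the registered model name could look like the sketch below. This is not part of the commit; the base URL, port, and key are placeholder assumptions for a locally running `litellm --config /path/to/config.yaml`.

```python
import openai

# Point the standard OpenAI client at the LiteLLM proxy (address assumed).
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:8000")

# "sagemaker-completion-model" is the model_name registered in config.yaml;
# the proxy applies the custom per-token pricing when logging spend.
response = client.chat.completions.create(
    model="sagemaker-completion-model",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response)
```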

View file

@@ -139,13 +139,13 @@ const sidebars = {
      "items": [
        "proxy/call_hooks",
        "proxy/rules",
+       "proxy/custom_pricing"
      ]
    },
    "proxy/deploy",
    "proxy/cli",
  ]
},
-"proxy/custom_pricing",
"routing",
"rules",
"set_keys",

View file

@@ -706,15 +706,16 @@ class OpenAIChatCompletion(BaseLLM):
            ## COMPLETION CALL
            response = openai_client.images.generate(**data, timeout=timeout)  # type: ignore
+           response = response.model_dump()  # type: ignore
            ## LOGGING
            logging_obj.post_call(
-               input=input,
+               input=prompt,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=response,
            )
            # return response
-           return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="image_generation")  # type: ignore
+           return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation")  # type: ignore
        except OpenAIError as e:
            exception_mapping_worked = True
            raise e
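The openai.py change above dumps the SDK's pydantic response to a plain dict once, immediately after the call, so the post-call logger and the response converter both receive the same payload instead of `model_dump()` being deferred to return time. A minimal sketch of that pattern, using a hypothetical stand-in class rather than the real SDK types:

```python
from pydantic import BaseModel

class FakeImagesResponse(BaseModel):
    # Hypothetical stand-in for the OpenAI images response object.
    created: int
    data: list[dict]

response = FakeImagesResponse(created=0, data=[{"url": "https://example.com/img.png"}])
response = response.model_dump()  # convert once, right after the call

log_payload = response  # the logger sees the plain dict...
converted = response    # ...and the converter reuses the very same dict
assert log_payload is converted
```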

View file

@@ -83,6 +83,7 @@ from litellm.utils import (
    TextCompletionResponse,
    TextChoices,
    EmbeddingResponse,
+   ImageResponse,
    read_config_args,
    Choices,
    Message,
@@ -2987,6 +2988,7 @@ def image_generation(
        else:
            model = "dall-e-2"
            custom_llm_provider = "openai"  # default to dall-e-2 on openai
+   model_response._hidden_params["model"] = model
    openai_params = [
        "user",
        "request_timeout",

View file

@@ -587,6 +587,10 @@ async def track_cost_callback(
                start_time=start_time,
                end_time=end_time,
            )
+       else:
+           raise Exception(
+               f"Model={kwargs['model']} not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
+           )
    except Exception as e:
        verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
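The new `else` branch surfaces a clear error when a model has no pricing entry, instead of silently recording zero spend (the surrounding handler still catches it and logs at debug level). The guard boils down to a membership test on `litellm.model_cost`; a sketch with a hypothetical model name:

```python
import litellm

model = "sagemaker/my-custom-endpoint"  # hypothetical - not in the default cost map

if model in litellm.model_cost:
    print(litellm.model_cost[model])  # pricing metadata, e.g. input_cost_per_token
else:
    # Mirrors the new error path: register custom pricing rather than
    # letting spend go untracked.
    print("not in cost map - add custom pricing (docs/proxy/custom_pricing)")
```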

View file

@@ -74,6 +74,7 @@ class CompletionCustomHandler(
    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        try:
+           print(f"kwargs: {kwargs}")
            self.states.append("post_api_call")
            ## START TIME
            assert isinstance(start_time, datetime)
@@ -149,7 +150,14 @@
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
-           assert isinstance(response_obj, litellm.ModelResponse)
+           assert isinstance(
+               response_obj,
+               Union[
+                   litellm.ModelResponse,
+                   litellm.EmbeddingResponse,
+                   litellm.ImageResponse,
+               ],
+           )
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list) and isinstance(
@@ -177,6 +185,7 @@
    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
+           print(f"kwargs: {kwargs}")
            self.states.append("sync_failure")
            ## START TIME
            assert isinstance(start_time, datetime)
@@ -766,6 +775,52 @@ async def test_async_embedding_azure_caching():
    assert len(customHandler_caching.states) == 4  # pre, post, success, success

# asyncio.run(
#     test_async_embedding_azure_caching()
# )
+
+# Image Generation
+
+## Test OpenAI + Sync
+# def test_image_generation_openai():
+#     try:
+#         customHandler_success = CompletionCustomHandler()
+#         customHandler_failure = CompletionCustomHandler()
+#         litellm.callbacks = [customHandler_success]
+#         litellm.set_verbose = True
+#         response = litellm.image_generation(
+#             prompt="A cute baby sea otter", model="dall-e-3"
+#         )
+#         print(f"response: {response}")
+#         assert len(response.data) > 0
+#         print(f"customHandler_success.errors: {customHandler_success.errors}")
+#         print(f"customHandler_success.states: {customHandler_success.states}")
+#         assert len(customHandler_success.errors) == 0
+#         assert len(customHandler_success.states) == 3  # pre, post, success
+#         # test failure callback
+#         litellm.callbacks = [customHandler_failure]
+#         try:
+#             response = litellm.image_generation(
+#                 prompt="A cute baby sea otter", model="dall-e-4"
+#             )
+#         except:
+#             pass
+#         print(f"customHandler_failure.errors: {customHandler_failure.errors}")
+#         print(f"customHandler_failure.states: {customHandler_failure.states}")
+#         assert len(customHandler_failure.errors) == 0
+#         assert len(customHandler_failure.states) == 3  # pre, post, failure
+#     except litellm.RateLimitError as e:
+#         pass
+#     except litellm.ContentPolicyViolationError:
+#         pass  # OpenAI randomly raises these errors - skip when they occur
+#     except Exception as e:
+#         pytest.fail(f"An exception occurred - {str(e)}")
+# test_image_generation_openai()
+
+## Test OpenAI + Async
+
+## Test Azure + Sync
+
+## Test Azure + Async
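One portability note on the widened assertion in this file: `isinstance()` accepts `typing.Union` at runtime only on Python 3.10+; on older interpreters the equivalent tuple form below is needed (a sketch, not from the commit):

```python
import litellm

def is_known_response(response_obj) -> bool:
    # Tuple form - equivalent to the Union-based assert, but works pre-3.10.
    return isinstance(
        response_obj,
        (litellm.ModelResponse, litellm.EmbeddingResponse, litellm.ImageResponse),
    )
```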

View file

@@ -1064,13 +1064,21 @@ class Logging:
            self.model_call_details["log_event_type"] = "successful_api_call"
            self.model_call_details["end_time"] = end_time
            self.model_call_details["cache_hit"] = cache_hit
-           if result is not None and (
-               isinstance(result, ModelResponse)
-               or isinstance(result, EmbeddingResponse)
+           ## if model in model cost map - log the response cost
+           ## else set cost to None
+           if (
+               result is not None
+               and (
+                   isinstance(result, ModelResponse)
+                   or isinstance(result, EmbeddingResponse)
+               )
+               and result.model in litellm.model_cost
            ):
                self.model_call_details["response_cost"] = litellm.completion_cost(
                    completion_response=result,
                )
+           else:  # streaming chunks + image gen.
+               self.model_call_details["response_cost"] = None

            if litellm.max_budget and self.stream:
                time_diff = (end_time - start_time).total_seconds()
@@ -1084,7 +1092,7 @@
            return start_time, end_time, result
        except Exception as e:
-           print_verbose(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}")
+           raise Exception(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}")

    def success_handler(
        self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
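With this Logging change, `kwargs["response_cost"]` can now be `None` (streaming chunks, image generation, or models missing from the cost map), so success callbacks should handle that case. A minimal defensive sketch using the standard custom-callback signature; the handler name is illustrative:

```python
import litellm

def track_cost(kwargs, completion_response, start_time, end_time):
    cost = kwargs.get("response_cost")  # may be None after this change
    if cost is not None:
        print(f"response_cost: ${cost:.6f}")
    else:
        print("no cost computed (stream chunk, image gen, or unmapped model)")

litellm.success_callback = [track_cost]
```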

View file

@@ -906,6 +906,14 @@
        "litellm_provider": "bedrock",
        "mode": "chat"
    },
+   "amazon.titan-embed-text-v1": {
+       "max_tokens": 8192,
+       "output_vector_size": 1536,
+       "input_cost_per_token": 0.0000001,
+       "output_cost_per_token": 0.0,
+       "litellm_provider": "bedrock",
+       "mode": "embedding"
+   },
    "anthropic.claude-v1": {
        "max_tokens": 100000,
        "max_output_tokens": 8191,