fix(utils.py): return 'response_cost' in completion call

Closes https://github.com/BerriAI/litellm/issues/4335
This commit is contained in:
Krrish Dholakia 2024-06-26 17:55:57 -07:00
parent 151d19960e
commit f533e1da09
4 changed files with 260 additions and 64 deletions

View file

@@ -433,6 +433,7 @@ def get_custom_headers(
api_base: Optional[str] = None,
version: Optional[str] = None,
model_region: Optional[str] = None,
response_cost: Optional[Union[float, str]] = None,
fastest_response_batch_completion: Optional[bool] = None,
**kwargs,
) -> dict:
@@ -443,6 +444,7 @@ def get_custom_headers(
"x-litellm-model-api-base": api_base,
"x-litellm-version": version,
"x-litellm-model-region": model_region,
"x-litellm-response-cost": str(response_cost),
"x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
"x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
"x-litellm-fastest_response_batch_completion": (
@@ -3048,6 +3050,7 @@ async def chat_completion(
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
fastest_response_batch_completion = hidden_params.get(
"fastest_response_batch_completion", None
)
@@ -3066,6 +3069,7 @@ async def chat_completion(
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
fastest_response_batch_completion=fastest_response_batch_completion,
)
@@ -3095,6 +3099,7 @@ async def chat_completion(
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
fastest_response_batch_completion=fastest_response_batch_completion,
**additional_headers,
@@ -3290,6 +3295,7 @@ async def completion(
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
@@ -3304,6 +3310,7 @@ async def completion(
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
)
selected_data_generator = select_data_generator(
response=response,
@@ -3323,6 +3330,7 @@ async def completion(
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
)
)
@@ -3527,6 +3535,7 @@ async def embeddings(
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
fastapi_response.headers.update(
get_custom_headers(
@@ -3535,6 +3544,7 @@ async def embeddings(
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
@@ -3676,6 +3686,7 @@ async def image_generation(
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
fastapi_response.headers.update(
get_custom_headers(
@@ -3684,6 +3695,7 @@ async def image_generation(
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
@@ -3812,6 +3824,7 @@ async def audio_speech(
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
# Printing each chunk size
async def generate(_response: HttpxBinaryResponseContent):
@@ -3825,6 +3838,7 @@ async def audio_speech(
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
fastest_response_batch_completion=None,
)
@@ -3976,6 +3990,7 @@ async def audio_transcriptions(
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
fastapi_response.headers.update(
get_custom_headers(
@@ -3984,6 +3999,7 @@ async def audio_transcriptions(
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)