Merge pull request #2426 from BerriAI/litellm_whisper_cost_tracking

feat: add cost tracking + caching for `/audio/transcription` calls
Krish Dholakia 2024-03-09 19:12:06 -08:00 committed by GitHub
commit c7d0af0a2e
11 changed files with 247 additions and 41 deletions
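
In practice, the change surfaces in two ways: transcription responses can be priced from their wall-clock duration, and `transcription`/`atranscription` calls can be cached like completions and embeddings. A rough end-to-end sketch (an OpenAI key is assumed to be configured, and `gettysburg.wav` is a placeholder for any local audio file):

    import litellm

    # transcribe a local audio file
    audio_file = open("gettysburg.wav", "rb")
    transcript = litellm.transcription(model="whisper-1", file=audio_file)

    # the response now carries provider info in _hidden_params, so a
    # duration-based cost can be computed from output_cost_per_second pricing
    cost = litellm.completion_cost(model="whisper-1", completion_response=transcript)
    print(transcript.text, cost)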


@@ -10,7 +10,7 @@
 import litellm
 import time, logging, asyncio
 import json, traceback, ast, hashlib
-from typing import Optional, Literal, List, Union, Any
+from typing import Optional, Literal, List, Union, Any, BinaryIO
 from openai._models import BaseModel as OpenAIObject
 from litellm._logging import verbose_logger
@@ -765,8 +765,24 @@ class Cache:
         password: Optional[str] = None,
         similarity_threshold: Optional[float] = None,
         supported_call_types: Optional[
-            List[Literal["completion", "acompletion", "embedding", "aembedding"]]
-        ] = ["completion", "acompletion", "embedding", "aembedding"],
+            List[
+                Literal[
+                    "completion",
+                    "acompletion",
+                    "embedding",
+                    "aembedding",
+                    "atranscription",
+                    "transcription",
+                ]
+            ]
+        ] = [
+            "completion",
+            "acompletion",
+            "embedding",
+            "aembedding",
+            "atranscription",
+            "transcription",
+        ],
         # s3 Bucket, boto3 configuration
         s3_bucket_name: Optional[str] = None,
         s3_region_name: Optional[str] = None,
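
With the widened `supported_call_types`, audio calls can opt into the same cache as chat and embedding calls. A minimal sketch using the in-memory cache (the Redis/S3 parameters in the signature above configure remote backends the same way):

    import litellm

    litellm.cache = litellm.Cache(
        type="local",  # assumes the default in-memory backend
        supported_call_types=[
            "completion",
            "acompletion",
            "transcription",
            "atranscription",
        ],
    )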
@@ -881,9 +897,14 @@ class Cache:
             "input",
             "encoding_format",
         ]  # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
+        transcription_only_kwargs = [
+            "file",
+            "language",
+        ]
         # combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
-        combined_kwargs = completion_kwargs + embedding_only_kwargs
+        combined_kwargs = (
+            completion_kwargs + embedding_only_kwargs + transcription_only_kwargs
+        )
         for param in combined_kwargs:
             # ignore litellm params here
             if param in kwargs:
@@ -915,6 +936,17 @@ class Cache:
                     param_value = (
                         caching_group or model_group or kwargs[param]
                     )  # use caching_group, if set then model_group if it exists, else use kwargs["model"]
+                elif param == "file":
+                    metadata_file_name = kwargs.get("metadata", {}).get(
+                        "file_name", None
+                    )
+                    litellm_params_file_name = kwargs.get("litellm_params", {}).get(
+                        "file_name", None
+                    )
+                    if metadata_file_name is not None:
+                        param_value = metadata_file_name
+                    elif litellm_params_file_name is not None:
+                        param_value = litellm_params_file_name
                 else:
                     if kwargs[param] is None:
                         continue  # ignore None params
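
Because the raw `file` object is neither stable nor hashable across requests, the cache key falls back to a caller-supplied file name from `metadata` (or `litellm_params`). A hedged sketch of what that looks like from the SDK side (passing `metadata` directly here is an assumption; the proxy sets it automatically, as shown further down):

    import litellm

    litellm.cache = litellm.Cache(supported_call_types=["transcription", "atranscription"])

    audio_file = open("gettysburg.wav", "rb")  # placeholder path
    # metadata["file_name"] becomes part of the cache key, so a repeat call with
    # the same name, model, and language can be served from cache
    transcript = litellm.transcription(
        model="whisper-1",
        file=audio_file,
        metadata={"file_name": "gettysburg.wav"},
    )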
@@ -1144,8 +1176,24 @@ def enable_cache(
     port: Optional[str] = None,
     password: Optional[str] = None,
     supported_call_types: Optional[
-        List[Literal["completion", "acompletion", "embedding", "aembedding"]]
-    ] = ["completion", "acompletion", "embedding", "aembedding"],
+        List[
+            Literal[
+                "completion",
+                "acompletion",
+                "embedding",
+                "aembedding",
+                "atranscription",
+                "transcription",
+            ]
+        ]
+    ] = [
+        "completion",
+        "acompletion",
+        "embedding",
+        "aembedding",
+        "atranscription",
+        "transcription",
+    ],
     **kwargs,
 ):
     """
@@ -1193,8 +1241,24 @@ def update_cache(
     port: Optional[str] = None,
     password: Optional[str] = None,
     supported_call_types: Optional[
-        List[Literal["completion", "acompletion", "embedding", "aembedding"]]
-    ] = ["completion", "acompletion", "embedding", "aembedding"],
+        List[
+            Literal[
+                "completion",
+                "acompletion",
+                "embedding",
+                "aembedding",
+                "atranscription",
+                "transcription",
+            ]
+        ]
+    ] = [
+        "completion",
+        "acompletion",
+        "embedding",
+        "aembedding",
+        "atranscription",
+        "transcription",
+    ],
     **kwargs,
 ):
     """


@@ -861,7 +861,8 @@ class AzureChatCompletion(BaseLLM):
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+            hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
+            final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
             return final_response

     async def async_audio_transcriptions(
@@ -921,7 +922,8 @@
                 },
                 original_response=stringified_response,
             )
-            response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+            hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
+            response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
             return response
         except Exception as e:
             ## LOGGING


@@ -753,6 +753,7 @@ class OpenAIChatCompletion(BaseLLM):
             # return response
             return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation")  # type: ignore
         except OpenAIError as e:
+
             exception_mapping_worked = True
             ## LOGGING
             logging_obj.post_call(
@@ -824,7 +825,8 @@
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+            hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
+            final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
             return final_response

     async def async_audio_transcriptions(
@@ -862,7 +864,8 @@
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+            hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
+            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
         except Exception as e:
             ## LOGGING
             logging_obj.post_call(


@@ -3295,6 +3295,7 @@ async def audio_transcriptions(
             user_api_key_dict, "team_id", None
         )
         data["metadata"]["endpoint"] = str(request.url)
+        data["metadata"]["file_name"] = file.filename

         ### TEAM-SPECIFIC PARAMS ###
         if user_api_key_dict.team_id is not None:
@@ -3329,7 +3330,7 @@
             data = await proxy_logging_obj.pre_call_hook(
                 user_api_key_dict=user_api_key_dict,
                 data=data,
-                call_type="moderation",
+                call_type="audio_transcription",
             )

         ## ROUTE TO CORRECT ENDPOINT ##
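
On the proxy, the uploaded file's name is stored in request metadata (giving caching and logging a stable identifier) and the pre-call hook now runs with the correct `audio_transcription` call type instead of `moderation`. Calling the endpoint through the OpenAI SDK might look like this (base URL and key are placeholders for a locally running proxy):

    from openai import OpenAI

    client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")  # placeholder proxy URL/key

    audio_file = open("gettysburg.wav", "rb")  # placeholder path
    transcript = client.audio.transcriptions.create(model="whisper-1", file=audio_file)
    print(transcript.text)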


@@ -96,7 +96,11 @@ class ProxyLogging:
         user_api_key_dict: UserAPIKeyAuth,
         data: dict,
         call_type: Literal[
-            "completion", "embeddings", "image_generation", "moderation"
+            "completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
         ],
     ):
         """


@@ -6,7 +6,12 @@ sys.path.insert(
 )  # Adds the parent directory to the system path
 import time
 import litellm
-from litellm import get_max_tokens, model_cost, open_ai_chat_completion_models
+from litellm import (
+    get_max_tokens,
+    model_cost,
+    open_ai_chat_completion_models,
+    TranscriptionResponse,
+)
 import pytest
@@ -238,3 +243,57 @@ def test_cost_bedrock_pricing_actual_calls():
         messages=[{"role": "user", "content": "Hey, how's it going?"}],
     )
     assert cost > 0
+
+
+def test_whisper_openai():
+    litellm.set_verbose = True
+    transcription = TranscriptionResponse(
+        text="Four score and seven years ago, our fathers brought forth on this continent a new nation, conceived in liberty and dedicated to the proposition that all men are created equal. Now we are engaged in a great civil war, testing whether that nation, or any nation so conceived and so dedicated, can long endure."
+    )
+    transcription._hidden_params = {
+        "model": "whisper-1",
+        "custom_llm_provider": "openai",
+        "optional_params": {},
+        "model_id": None,
+    }
+    _total_time_in_seconds = 3
+
+    transcription._response_ms = _total_time_in_seconds * 1000
+    cost = litellm.completion_cost(model="whisper-1", completion_response=transcription)
+
+    print(f"cost: {cost}")
+    print(f"whisper dict: {litellm.model_cost['whisper-1']}")
+    expected_cost = round(
+        litellm.model_cost["whisper-1"]["output_cost_per_second"]
+        * _total_time_in_seconds,
+        5,
+    )
+    assert cost == expected_cost
+
+
+def test_whisper_azure():
+    litellm.set_verbose = True
+    transcription = TranscriptionResponse(
+        text="Four score and seven years ago, our fathers brought forth on this continent a new nation, conceived in liberty and dedicated to the proposition that all men are created equal. Now we are engaged in a great civil war, testing whether that nation, or any nation so conceived and so dedicated, can long endure."
+    )
+    transcription._hidden_params = {
+        "model": "whisper-1",
+        "custom_llm_provider": "azure",
+        "optional_params": {},
+        "model_id": None,
+    }
+    _total_time_in_seconds = 3
+
+    transcription._response_ms = _total_time_in_seconds * 1000
+    cost = litellm.completion_cost(
+        model="azure/azure-whisper", completion_response=transcription
+    )
+
+    print(f"cost: {cost}")
+    print(f"whisper dict: {litellm.model_cost['whisper-1']}")
+    expected_cost = round(
+        litellm.model_cost["whisper-1"]["output_cost_per_second"]
+        * _total_time_in_seconds,
+        5,
+    )
+    assert cost == expected_cost


@@ -973,6 +973,7 @@ def test_image_generation_openai():
     print(f"customHandler_success.errors: {customHandler_success.errors}")
     print(f"customHandler_success.states: {customHandler_success.states}")
+    time.sleep(2)
     assert len(customHandler_success.errors) == 0
     assert len(customHandler_success.states) == 3  # pre, post, success

     # test failure callback


@@ -100,7 +100,7 @@ class TmpFunction:
 def test_async_chat_openai_stream():
     try:
         tmp_function = TmpFunction()
-        # litellm.set_verbose = True
+        litellm.set_verbose = True
         litellm.success_callback = [tmp_function.async_test_logging_fn]
         complete_streaming_response = ""


@@ -336,6 +336,8 @@ def test_load_router_config():
         "acompletion",
         "embedding",
         "aembedding",
+        "atranscription",
+        "transcription",
     ]  # init with all call types

     litellm.disable_cache()


@@ -1168,6 +1168,7 @@ class Logging:
                     isinstance(result, ModelResponse)
                     or isinstance(result, EmbeddingResponse)
                     or isinstance(result, ImageResponse)
+                    or isinstance(result, TranscriptionResponse)
                 )
                 and self.stream != True
             ):  # handle streaming separately
@@ -1203,9 +1204,6 @@
                                 model=base_model,
                             )
                         )
-                        verbose_logger.debug(
-                            f"Model={self.model}; cost={self.model_call_details['response_cost']}"
-                        )
                     except litellm.NotFoundError as e:
                         verbose_logger.debug(
                             f"Model={self.model} not found in completion cost map."
@@ -1236,7 +1234,7 @@
     def success_handler(
         self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
     ):
-        verbose_logger.debug(f"Logging Details LiteLLM-Success Call: {cache_hit}")
+        print_verbose(f"Logging Details LiteLLM-Success Call: {cache_hit}")
         start_time, end_time, result = self._success_handler_helper_fn(
             start_time=start_time,
             end_time=end_time,
@@ -1245,7 +1243,7 @@
         )
         # print(f"original response in success handler: {self.model_call_details['original_response']}")
         try:
-            verbose_logger.debug(f"success callbacks: {litellm.success_callback}")
+            print_verbose(f"success callbacks: {litellm.success_callback}")
             ## BUILD COMPLETE STREAMED RESPONSE
             complete_streaming_response = None
             if self.stream and isinstance(result, ModelResponse):
@@ -1268,7 +1266,7 @@
                 self.sync_streaming_chunks.append(result)

             if complete_streaming_response is not None:
-                verbose_logger.debug(
+                print_verbose(
                     f"Logging Details LiteLLM-Success Call streaming complete"
                 )
                 self.model_call_details["complete_streaming_response"] = (
@@ -1615,6 +1613,14 @@
                             "aembedding", False
                         )
                         == False
+                        and self.model_call_details.get("litellm_params", {}).get(
+                            "aimage_generation", False
+                        )
+                        == False
+                        and self.model_call_details.get("litellm_params", {}).get(
+                            "atranscription", False
+                        )
+                        == False
                     ):  # custom logger class
                         if self.stream and complete_streaming_response is None:
                             callback.log_stream_event(
@@ -1647,6 +1653,14 @@
                             "aembedding", False
                         )
                         == False
+                        and self.model_call_details.get("litellm_params", {}).get(
+                            "aimage_generation", False
+                        )
+                        == False
+                        and self.model_call_details.get("litellm_params", {}).get(
+                            "atranscription", False
+                        )
+                        == False
                     ):  # custom logger functions
                         print_verbose(
                             f"success callbacks: Running Custom Callback Function"
@@ -1681,6 +1695,7 @@
         """
        Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
        """
+        print_verbose(f"Logging Details LiteLLM-Async Success Call: {cache_hit}")
         start_time, end_time, result = self._success_handler_helper_fn(
             start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
         )
@@ -2473,6 +2488,7 @@ def client(original_function):
                 and kwargs.get("aembedding", False) != True
                 and kwargs.get("acompletion", False) != True
                 and kwargs.get("aimg_generation", False) != True
+                and kwargs.get("atranscription", False) != True
             ):  # allow users to control returning cached responses from the completion function
                 # checking cache
                 print_verbose(f"INSIDE CHECKING CACHE")
@@ -2875,6 +2891,19 @@ def client(original_function):
                                 model_response_object=EmbeddingResponse(),
                                 response_type="embedding",
                             )
+                        elif call_type == CallTypes.atranscription.value and isinstance(
+                            cached_result, dict
+                        ):
+                            hidden_params = {
+                                "model": "whisper-1",
+                                "custom_llm_provider": custom_llm_provider,
+                            }
+                            cached_result = convert_to_model_response_object(
+                                response_object=cached_result,
+                                model_response_object=TranscriptionResponse(),
+                                response_type="audio_transcription",
+                                hidden_params=hidden_params,
+                            )
                         if kwargs.get("stream", False) == False:
                             # LOG SUCCESS
                             asyncio.create_task(
@@ -3001,6 +3030,20 @@ def client(original_function):
                 else:
                     return result

+            # ADD HIDDEN PARAMS - additional call metadata
+            if hasattr(result, "_hidden_params"):
+                result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
+                    "id", None
+                )
+            if (
+                isinstance(result, ModelResponse)
+                or isinstance(result, EmbeddingResponse)
+                or isinstance(result, TranscriptionResponse)
+            ):
+                result._response_ms = (
+                    end_time - start_time
+                ).total_seconds() * 1000  # return response latency in ms like openai
+
             ### POST-CALL RULES ###
             post_call_processing(original_response=result, model=model)
@@ -3013,8 +3056,10 @@
                 )
                 and (kwargs.get("cache", {}).get("no-store", False) != True)
             ):
-                if isinstance(result, litellm.ModelResponse) or isinstance(
-                    result, litellm.EmbeddingResponse
+                if (
+                    isinstance(result, litellm.ModelResponse)
+                    or isinstance(result, litellm.EmbeddingResponse)
+                    or isinstance(result, TranscriptionResponse)
                 ):
                     if (
                         isinstance(result, EmbeddingResponse)
@@ -3058,18 +3103,7 @@
                             args=(result, start_time, end_time),
                         ).start()

-            # RETURN RESULT
-            if hasattr(result, "_hidden_params"):
-                result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
-                    "id", None
-                )
-            if isinstance(result, ModelResponse) or isinstance(
-                result, EmbeddingResponse
-            ):
-                result._response_ms = (
-                    end_time - start_time
-                ).total_seconds() * 1000  # return response latency in ms like openai
+            # REBUILD EMBEDDING CACHING
             if (
                 isinstance(result, EmbeddingResponse)
                 and final_embedding_cached_response is not None
@@ -3575,6 +3609,20 @@ def cost_per_token(
         completion_tokens_cost_usd_dollar = (
             model_cost_ref[model]["output_cost_per_token"] * completion_tokens
         )
+    elif (
+        model_cost_ref[model].get("output_cost_per_second", None) is not None
+        and response_time_ms is not None
+    ):
+        print_verbose(
+            f"For model={model} - output_cost_per_second: {model_cost_ref[model].get('output_cost_per_second')}; response time: {response_time_ms}"
+        )
+        ## COST PER SECOND ##
+        prompt_tokens_cost_usd_dollar = 0
+        completion_tokens_cost_usd_dollar = (
+            model_cost_ref[model]["output_cost_per_second"]
+            * response_time_ms
+            / 1000
+        )
     elif (
         model_cost_ref[model].get("input_cost_per_second", None) is not None
         and response_time_ms is not None
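
The new branch mirrors the existing `input_cost_per_second` handling: with no token counts available, the entire cost is attributed to the completion side and scaled by wall-clock duration. A worked example with an assumed price (the real value lives under `litellm.model_cost["whisper-1"]`):

    # assumed pricing, for illustration only
    output_cost_per_second = 0.0001
    response_time_ms = 3000  # a 3-second transcription

    prompt_tokens_cost_usd_dollar = 0
    completion_tokens_cost_usd_dollar = (
        output_cost_per_second * response_time_ms / 1000
    )  # 0.0001 * 3 = 0.0003 USD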
@@ -3659,6 +3707,8 @@ def completion_cost(
         "text_completion",
         "image_generation",
         "aimage_generation",
+        "transcription",
+        "atranscription",
     ] = "completion",
     ### REGION ###
     custom_llm_provider=None,
@@ -3694,7 +3744,6 @@ def completion_cost(
     - If an error occurs during execution, the function returns 0.0 without blocking the user's execution path.
     """
     try:
-
         if (
             (call_type == "aimage_generation" or call_type == "image_generation")
             and model is not None
@@ -3717,10 +3766,15 @@
             verbose_logger.debug(
                 f"completion_response response ms: {completion_response.get('_response_ms')} "
             )
-            model = (
-                model or completion_response["model"]
-            )  # check if user passed an override for model, if it's none check completion_response['model']
+            model = model or completion_response.get(
+                "model", None
+            )  # check if user passed an override for model, if it's none check completion_response['model']
             if hasattr(completion_response, "_hidden_params"):
+                if (
+                    completion_response._hidden_params.get("model", None) is not None
+                    and len(completion_response._hidden_params["model"]) > 0
+                ):
+                    model = completion_response._hidden_params.get("model", model)
                 custom_llm_provider = completion_response._hidden_params.get(
                     "custom_llm_provider", ""
                 )
@@ -3801,6 +3855,7 @@ def completion_cost(
         # see https://replicate.com/pricing
         elif model in litellm.replicate_models or "replicate" in model:
             return get_replicate_completion_pricing(completion_response, total_time)
+
         (
             prompt_tokens_cost_usd_dollar,
             completion_tokens_cost_usd_dollar,
@@ -6314,6 +6369,7 @@ def convert_to_model_response_object(
     stream=False,
     start_time=None,
     end_time=None,
+    hidden_params: Optional[dict] = None,
 ):
     try:
         if response_type == "completion" and (
@@ -6373,6 +6429,9 @@
                     end_time - start_time
                 ).total_seconds() * 1000

+            if hidden_params is not None:
+                model_response_object._hidden_params = hidden_params
+
             return model_response_object
         elif response_type == "embedding" and (
             model_response_object is None
@@ -6402,6 +6461,9 @@
                     end_time - start_time
                 ).total_seconds() * 1000  # return response latency in ms like openai

+            if hidden_params is not None:
+                model_response_object._hidden_params = hidden_params
+
             return model_response_object
         elif response_type == "image_generation" and (
             model_response_object is None
@@ -6419,6 +6481,9 @@
             if "data" in response_object:
                 model_response_object.data = response_object["data"]

+            if hidden_params is not None:
+                model_response_object._hidden_params = hidden_params
+
             return model_response_object
         elif response_type == "audio_transcription" and (
             model_response_object is None
@@ -6432,6 +6497,9 @@
             if "text" in response_object:
                 model_response_object.text = response_object["text"]

+            if hidden_params is not None:
+                model_response_object._hidden_params = hidden_params
+
             return model_response_object
     except Exception as e:
         raise Exception(f"Invalid response object {traceback.format_exc()}")


@@ -31,7 +31,8 @@ def test_transcription():
         model="whisper-1",
         file=audio_file,
     )
-    print(f"transcript: {transcript}")
+    print(f"transcript: {transcript.model_dump()}")
+    print(f"transcript: {transcript._hidden_params}")


 # test_transcription()
@@ -47,6 +48,7 @@ def test_transcription_azure():
         api_version="2024-02-15-preview",
     )
+    print(f"transcript: {transcript}")
     assert transcript.text is not None
     assert isinstance(transcript.text, str)
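
For completeness, the async path goes through the same caching and hidden-params handling; a rough sketch, assuming `litellm.atranscription` takes the same arguments as its sync counterpart and the OpenAI key is set in the environment:

    import asyncio
    import litellm

    litellm.cache = litellm.Cache(supported_call_types=["transcription", "atranscription"])

    async def main():
        audio_file = open("gettysburg.wav", "rb")  # placeholder path
        transcript = await litellm.atranscription(
            model="whisper-1",
            file=audio_file,
            metadata={"file_name": "gettysburg.wav"},  # stable cache-key component
        )
        print(transcript.text)
        print(litellm.completion_cost(model="whisper-1", completion_response=transcript))

    asyncio.run(main())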