mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
(QA+UI) - e2e flow for adding assembly ai passthrough endpoints (#8337)
* add initial test for assembly ai * start using PassthroughEndpointRouter * migrate to llm passthrough endpoints * add assembly ai as a known provider * fix PassthroughEndpointRouter * fix set_pass_through_credentials * working EU request to assembly ai pass through endpoint * add e2e test assembly * test_assemblyai_routes_with_bad_api_key * clean up pass through endpoint router * e2e testing for assembly ai pass through * test assembly ai e2e testing * delete assembly ai models * fix code quality * ui working assembly ai api base flow * fix install assembly ai * update model call details with kwargs for pass through logging * fix tracking assembly ai model in response * _handle_assemblyai_passthrough_logging * fix test_initialize_deployment_for_pass_through_unsupported_provider * TestPassthroughEndpointRouter * _get_assembly_transcript * fix assembly ai pt logging tests * fix assemblyai_proxy_route * fix _get_assembly_region_from_url
This commit is contained in:
parent
5dcb87a88b
commit
65c91cbbbc
13 changed files with 656 additions and 79 deletions
|
@ -1,12 +1,13 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional, TypedDict
|
||||
from typing import Literal, Optional, TypedDict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
|
@ -18,7 +19,7 @@ from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggin
|
|||
|
||||
class AssemblyAITranscriptResponse(TypedDict, total=False):
|
||||
id: str
|
||||
language_model: str
|
||||
speech_model: str
|
||||
acoustic_model: str
|
||||
language_code: str
|
||||
status: str
|
||||
|
@ -27,7 +28,8 @@ class AssemblyAITranscriptResponse(TypedDict, total=False):
|
|||
|
||||
class AssemblyAIPassthroughLoggingHandler:
|
||||
def __init__(self):
|
||||
self.assembly_ai_base_url = "https://api.assemblyai.com/v2"
|
||||
self.assembly_ai_base_url = "https://api.assemblyai.com"
|
||||
self.assembly_ai_eu_base_url = "https://eu.assemblyai.com"
|
||||
"""
|
||||
The base URL for the AssemblyAI API
|
||||
"""
|
||||
|
@ -43,8 +45,6 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
The maximum number of polling attempts for the AssemblyAI API.
|
||||
"""
|
||||
|
||||
self.assemblyai_api_key = os.environ.get("ASSEMBLYAI_API_KEY")
|
||||
|
||||
def assemblyai_passthrough_logging_handler(
|
||||
self,
|
||||
httpx_response: httpx.Response,
|
||||
|
@ -90,27 +90,34 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
"""
|
||||
from ..pass_through_endpoints import pass_through_endpoint_logging
|
||||
|
||||
model = response_body.get("model", "")
|
||||
verbose_proxy_logger.debug("response body", json.dumps(response_body, indent=4))
|
||||
model = response_body.get("speech_model", "")
|
||||
verbose_proxy_logger.debug(
|
||||
"response body %s", json.dumps(response_body, indent=4)
|
||||
)
|
||||
kwargs["model"] = model
|
||||
kwargs["custom_llm_provider"] = "assemblyai"
|
||||
logging_obj.model_call_details["model"] = model
|
||||
logging_obj.model_call_details["custom_llm_provider"] = "assemblyai"
|
||||
|
||||
transcript_id = response_body.get("id")
|
||||
if transcript_id is None:
|
||||
raise ValueError(
|
||||
"Transcript ID is required to log the cost of the transcription"
|
||||
)
|
||||
transcript_response = self._poll_assembly_for_transcript_response(transcript_id)
|
||||
transcript_response = self._poll_assembly_for_transcript_response(
|
||||
transcript_id=transcript_id, url_route=url_route
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"finished polling assembly for transcript response- got transcript response",
|
||||
"finished polling assembly for transcript response- got transcript response %s",
|
||||
json.dumps(transcript_response, indent=4),
|
||||
)
|
||||
if transcript_response:
|
||||
cost = self.get_cost_for_assembly_transcript(transcript_response)
|
||||
cost = self.get_cost_for_assembly_transcript(
|
||||
speech_model=model,
|
||||
transcript_response=transcript_response,
|
||||
)
|
||||
kwargs["response_cost"] = cost
|
||||
|
||||
logging_obj.model_call_details["model"] = logging_obj.model
|
||||
|
||||
# Make standard logging object for AssemblyAI
|
||||
standard_logging_object = get_standard_logging_object_payload(
|
||||
kwargs=kwargs,
|
||||
|
@ -157,7 +164,11 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
return {}
|
||||
return dict(transcript_response)
|
||||
|
||||
def _get_assembly_transcript(self, transcript_id: str) -> Optional[dict]:
|
||||
def _get_assembly_transcript(
|
||||
self,
|
||||
transcript_id: str,
|
||||
request_region: Optional[Literal["eu"]] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the transcript details from AssemblyAI API
|
||||
|
||||
|
@ -167,10 +178,25 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
Returns:
|
||||
Optional[dict]: Transcript details if successful, None otherwise
|
||||
"""
|
||||
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
|
||||
passthrough_endpoint_router,
|
||||
)
|
||||
|
||||
_base_url = (
|
||||
self.assembly_ai_eu_base_url
|
||||
if request_region == "eu"
|
||||
else self.assembly_ai_base_url
|
||||
)
|
||||
_api_key = passthrough_endpoint_router.get_credentials(
|
||||
custom_llm_provider="assemblyai",
|
||||
region_name=request_region,
|
||||
)
|
||||
if _api_key is None:
|
||||
raise ValueError("AssemblyAI API key not found")
|
||||
try:
|
||||
url = f"{self.assembly_ai_base_url}/transcript/{transcript_id}"
|
||||
url = f"{_base_url}/v2/transcript/{transcript_id}"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.assemblyai_api_key}",
|
||||
"Authorization": f"Bearer {_api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
@ -179,11 +205,15 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"Error getting AssemblyAI transcript: {str(e)}")
|
||||
verbose_proxy_logger.exception(
|
||||
f"[Non blocking logging error] Error getting AssemblyAI transcript: {str(e)}"
|
||||
)
|
||||
return None
|
||||
|
||||
def _poll_assembly_for_transcript_response(
|
||||
self, transcript_id: str
|
||||
self,
|
||||
transcript_id: str,
|
||||
url_route: Optional[str] = None,
|
||||
) -> Optional[AssemblyAITranscriptResponse]:
|
||||
"""
|
||||
Poll the status of the transcript until it is completed or timeout (30 minutes)
|
||||
|
@ -191,7 +221,12 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
for _ in range(
|
||||
self.max_polling_attempts
|
||||
): # 180 attempts * 10s = 30 minutes max
|
||||
transcript = self._get_assembly_transcript(transcript_id)
|
||||
transcript = self._get_assembly_transcript(
|
||||
request_region=AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
|
||||
url=url_route
|
||||
),
|
||||
transcript_id=transcript_id,
|
||||
)
|
||||
if transcript is None:
|
||||
return None
|
||||
if (
|
||||
|
@ -205,6 +240,7 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
@staticmethod
|
||||
def get_cost_for_assembly_transcript(
|
||||
transcript_response: AssemblyAITranscriptResponse,
|
||||
speech_model: str,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
Get the cost for the assembly transcript
|
||||
|
@ -212,7 +248,50 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
_audio_duration = transcript_response.get("audio_duration")
|
||||
if _audio_duration is None:
|
||||
return None
|
||||
return _audio_duration * 0.0001
|
||||
_cost_per_second = (
|
||||
AssemblyAIPassthroughLoggingHandler.get_cost_per_second_for_assembly_model(
|
||||
speech_model=speech_model
|
||||
)
|
||||
)
|
||||
if _cost_per_second is None:
|
||||
return None
|
||||
return _audio_duration * _cost_per_second
|
||||
|
||||
@staticmethod
def get_cost_per_second_for_assembly_model(speech_model: str) -> Optional[float]:
    """
    Get the cost per second for the assembly model.

    Looks up `input_cost_per_second` through `litellm.get_model_info`, first
    for `speech_model` and then for the `assemblyai/nano` fallback.

    Args:
        speech_model: Name of the speech model to price.

    Returns:
        Optional[float]: The first non-None `input_cost_per_second` found,
        or None when neither lookup yields one. Lookup errors are caught and
        logged at debug level so cost tracking never blocks the request.
    """
    # Try the requested model first, then the documented fallback.
    for _model in (speech_model, "assemblyai/nano"):
        try:
            model_info = litellm.get_model_info(
                model=_model,
                custom_llm_provider="assemblyai",
            )
            _cost_per_second = (
                model_info.get("input_cost_per_second") if model_info else None
            )
        except Exception as e:
            # A missing model is expected here (hence the fallback); record it
            # instead of swallowing it silently, then move to the next candidate.
            verbose_proxy_logger.debug(
                "[Non blocking logging error] Error getting AssemblyAI model info for %s: %s",
                _model,
                str(e),
            )
            continue
        if _cost_per_second is not None:
            return _cost_per_second
    return None
|
||||
|
||||
@staticmethod
|
||||
def _should_log_request(request_method: str) -> bool:
|
||||
|
@ -220,3 +299,25 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
only POST transcription jobs are logged. litellm will POLL assembly to wait for the transcription to complete to log the complete response / cost
|
||||
"""
|
||||
return request_method == "POST"
|
||||
|
||||
@staticmethod
|
||||
def _get_assembly_region_from_url(url: Optional[str]) -> Optional[Literal["eu"]]:
|
||||
"""
|
||||
Get the region from the URL
|
||||
"""
|
||||
if url is None:
|
||||
return None
|
||||
if urlparse(url).hostname == "eu.assemblyai.com":
|
||||
return "eu"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_assembly_base_url_from_region(region: Optional[Literal["eu"]]) -> str:
|
||||
"""
|
||||
Get the base URL for the AssemblyAI API
|
||||
if region == "eu", return "https://api.eu.assemblyai.com"
|
||||
else return "https://api.assemblyai.com"
|
||||
"""
|
||||
if region == "eu":
|
||||
return "https://api.eu.assemblyai.com"
|
||||
return "https://api.assemblyai.com"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue