(QA+UI) - e2e flow for adding assembly ai passthrough endpoints (#8337)

* add initial test for assembly ai

* start using PassthroughEndpointRouter

* migrate to litellm passthrough endpoints

* add assembly ai as a known provider

* fix PassthroughEndpointRouter

* fix set_pass_through_credentials

* working EU request to assembly ai pass through endpoint

* add e2e test assembly

* test_assemblyai_routes_with_bad_api_key

* clean up pass through endpoint router

* e2e testing for assembly ai pass through

* test assembly ai e2e testing

* delete assembly ai models

* fix code quality

* ui working assembly ai api base flow

* fix install assembly ai

* update model call details with kwargs for pass through logging

* fix tracking assembly ai model in response

* _handle_assemblyai_passthrough_logging

* fix test_initialize_deployment_for_pass_through_unsupported_provider

* TestPassthroughEndpointRouter

* _get_assembly_transcript

* fix assembly ai pt logging tests

* fix assemblyai_proxy_route

* fix _get_assembly_region_from_url
This commit is contained in:
Ishaan Jaff 2025-02-06 18:27:54 -08:00 committed by GitHub
parent 5dcb87a88b
commit 65c91cbbbc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 656 additions and 79 deletions

View file

@ -1,12 +1,13 @@
import asyncio
import json
import os
import time
from datetime import datetime
from typing import Optional, TypedDict
from typing import Literal, Optional, TypedDict
from urllib.parse import urlparse
import httpx
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import (
@ -18,7 +19,7 @@ from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggin
class AssemblyAITranscriptResponse(TypedDict, total=False):
id: str
language_model: str
speech_model: str
acoustic_model: str
language_code: str
status: str
@ -27,7 +28,8 @@ class AssemblyAITranscriptResponse(TypedDict, total=False):
class AssemblyAIPassthroughLoggingHandler:
def __init__(self):
self.assembly_ai_base_url = "https://api.assemblyai.com/v2"
self.assembly_ai_base_url = "https://api.assemblyai.com"
self.assembly_ai_eu_base_url = "https://eu.assemblyai.com"
"""
The base URL for the AssemblyAI API
"""
@ -43,8 +45,6 @@ class AssemblyAIPassthroughLoggingHandler:
The maximum number of polling attempts for the AssemblyAI API.
"""
self.assemblyai_api_key = os.environ.get("ASSEMBLYAI_API_KEY")
def assemblyai_passthrough_logging_handler(
self,
httpx_response: httpx.Response,
@ -90,27 +90,34 @@ class AssemblyAIPassthroughLoggingHandler:
"""
from ..pass_through_endpoints import pass_through_endpoint_logging
model = response_body.get("model", "")
verbose_proxy_logger.debug("response body", json.dumps(response_body, indent=4))
model = response_body.get("speech_model", "")
verbose_proxy_logger.debug(
"response body %s", json.dumps(response_body, indent=4)
)
kwargs["model"] = model
kwargs["custom_llm_provider"] = "assemblyai"
logging_obj.model_call_details["model"] = model
logging_obj.model_call_details["custom_llm_provider"] = "assemblyai"
transcript_id = response_body.get("id")
if transcript_id is None:
raise ValueError(
"Transcript ID is required to log the cost of the transcription"
)
transcript_response = self._poll_assembly_for_transcript_response(transcript_id)
transcript_response = self._poll_assembly_for_transcript_response(
transcript_id=transcript_id, url_route=url_route
)
verbose_proxy_logger.debug(
"finished polling assembly for transcript response- got transcript response",
"finished polling assembly for transcript response- got transcript response %s",
json.dumps(transcript_response, indent=4),
)
if transcript_response:
cost = self.get_cost_for_assembly_transcript(transcript_response)
cost = self.get_cost_for_assembly_transcript(
speech_model=model,
transcript_response=transcript_response,
)
kwargs["response_cost"] = cost
logging_obj.model_call_details["model"] = logging_obj.model
# Make standard logging object for Vertex AI
standard_logging_object = get_standard_logging_object_payload(
kwargs=kwargs,
@ -157,7 +164,11 @@ class AssemblyAIPassthroughLoggingHandler:
return {}
return dict(transcript_response)
def _get_assembly_transcript(self, transcript_id: str) -> Optional[dict]:
def _get_assembly_transcript(
self,
transcript_id: str,
request_region: Optional[Literal["eu"]] = None,
) -> Optional[dict]:
"""
Get the transcript details from AssemblyAI API
@ -167,10 +178,25 @@ class AssemblyAIPassthroughLoggingHandler:
Returns:
Optional[dict]: Transcript details if successful, None otherwise
"""
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
passthrough_endpoint_router,
)
_base_url = (
self.assembly_ai_eu_base_url
if request_region == "eu"
else self.assembly_ai_base_url
)
_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider="assemblyai",
region_name=request_region,
)
if _api_key is None:
raise ValueError("AssemblyAI API key not found")
try:
url = f"{self.assembly_ai_base_url}/transcript/{transcript_id}"
url = f"{_base_url}/v2/transcript/{transcript_id}"
headers = {
"Authorization": f"Bearer {self.assemblyai_api_key}",
"Authorization": f"Bearer {_api_key}",
"Content-Type": "application/json",
}
@ -179,11 +205,15 @@ class AssemblyAIPassthroughLoggingHandler:
return response.json()
except Exception as e:
verbose_proxy_logger.debug(f"Error getting AssemblyAI transcript: {str(e)}")
verbose_proxy_logger.exception(
f"[Non blocking logging error] Error getting AssemblyAI transcript: {str(e)}"
)
return None
def _poll_assembly_for_transcript_response(
self, transcript_id: str
self,
transcript_id: str,
url_route: Optional[str] = None,
) -> Optional[AssemblyAITranscriptResponse]:
"""
Poll the status of the transcript until it is completed or timeout (30 minutes)
@ -191,7 +221,12 @@ class AssemblyAIPassthroughLoggingHandler:
for _ in range(
self.max_polling_attempts
): # 180 attempts * 10s = 30 minutes max
transcript = self._get_assembly_transcript(transcript_id)
transcript = self._get_assembly_transcript(
request_region=AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
url=url_route
),
transcript_id=transcript_id,
)
if transcript is None:
return None
if (
@ -205,6 +240,7 @@ class AssemblyAIPassthroughLoggingHandler:
@staticmethod
def get_cost_for_assembly_transcript(
transcript_response: AssemblyAITranscriptResponse,
speech_model: str,
) -> Optional[float]:
"""
Get the cost for the assembly transcript
@ -212,7 +248,50 @@ class AssemblyAIPassthroughLoggingHandler:
_audio_duration = transcript_response.get("audio_duration")
if _audio_duration is None:
return None
return _audio_duration * 0.0001
_cost_per_second = (
AssemblyAIPassthroughLoggingHandler.get_cost_per_second_for_assembly_model(
speech_model=speech_model
)
)
if _cost_per_second is None:
return None
return _audio_duration * _cost_per_second
@staticmethod
def get_cost_per_second_for_assembly_model(speech_model: str) -> Optional[float]:
    """
    Return the per-second input cost for an AssemblyAI speech model.

    Looks up ``input_cost_per_second`` in litellm's model info for the
    given speech model, falling back to ``assemblyai/nano`` when the
    specific model has no pricing information.

    Args:
        speech_model: AssemblyAI speech model name taken from the
            transcript response (e.g. "best" or "nano").

    Returns:
        Optional[float]: Cost per second of audio, or None if no pricing
        could be found. Never raises - this runs on the logging path.
    """
    try:
        # DRY: try the requested model first, then the nano fallback,
        # instead of duplicating the lookup block twice.
        for candidate_model in (speech_model, "assemblyai/nano"):
            try:
                model_info = litellm.get_model_info(
                    model=candidate_model,
                    custom_llm_provider="assemblyai",
                )
            except Exception:
                continue  # unknown model -> try the next candidate
            if model_info and model_info.get("input_cost_per_second") is not None:
                return model_info.get("input_cost_per_second")
        return None
    except Exception as e:
        # Non-blocking: cost lookup failures must never break request logging.
        verbose_proxy_logger.exception(
            f"[Non blocking logging error] Error getting AssemblyAI model info: {str(e)}"
        )
        return None
@staticmethod
def _should_log_request(request_method: str) -> bool:
@ -220,3 +299,25 @@ class AssemblyAIPassthroughLoggingHandler:
only POST transcription jobs are logged. litellm will POLL assembly to wait for the transcription to complete to log the complete response / cost
"""
return request_method == "POST"
@staticmethod
def _get_assembly_region_from_url(url: Optional[str]) -> Optional[Literal["eu"]]:
"""
Get the region from the URL
"""
if url is None:
return None
if urlparse(url).hostname == "eu.assemblyai.com":
return "eu"
return None
@staticmethod
def _get_assembly_base_url_from_region(region: Optional[Literal["eu"]]) -> str:
"""
Get the base URL for the AssemblyAI API
if region == "eu", return "https://api.eu.assemblyai.com"
else return "https://api.assemblyai.com"
"""
if region == "eu":
return "https://api.eu.assemblyai.com"
return "https://api.assemblyai.com"