(QA+UI) - e2e flow for adding assembly ai passthrough endpoints (#8337)

* add initial test for assembly ai

* start using PassthroughEndpointRouter

* migrate to llm passthrough endpoints

* add assembly ai as a known provider

* fix PassthroughEndpointRouter

* fix set_pass_through_credentials

* working EU request to assembly ai pass through endpoint (see the request sketch after this list)

* add e2e test assembly

* test_assemblyai_routes_with_bad_api_key

* clean up pass through endpoint router

* e2e testing for assembly ai pass through

* test assembly ai e2e testing

* delete assembly ai models

* fix code quality

* ui working assembly ai api base flow

* fix install assembly ai

* update model call details with kwargs for pass through logging

* fix tracking assembly ai model in response

* _handle_assemblyai_passthrough_logging

* fix test_initialize_deployment_for_pass_through_unsupported_provider

* TestPassthroughEndpointRouter

* _get_assembly_transcript

* fix assembly ai pt logging tests

* fix assemblyai_proxy_route

* fix _get_assembly_region_from_url
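
For context, a minimal sketch of the end-to-end flow these commits enable, assuming a LiteLLM proxy at http://localhost:4000 and a virtual key "sk-1234" (both hypothetical; the request body follows the public AssemblyAI transcript API):

import httpx

PROXY_BASE = "http://localhost:4000"  # hypothetical proxy address
headers = {"Authorization": "Bearer sk-1234"}  # hypothetical virtual key

# US region: the proxy forwards /assemblyai/* to https://api.assemblyai.com
resp = httpx.post(
    f"{PROXY_BASE}/assemblyai/v2/transcript",
    headers=headers,
    json={"audio_url": "https://example.com/audio.mp3"},  # hypothetical audio
)
transcript_id = resp.json()["id"]

# EU region: same request shape, routed through the new /eu.assemblyai prefix
resp_eu = httpx.post(
    f"{PROXY_BASE}/eu.assemblyai/v2/transcript",
    headers=headers,
    json={"audio_url": "https://example.com/audio.mp3"},
)
print(transcript_id, resp_eu.json().get("id"))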
Ishaan Jaff 2025-02-06 18:27:54 -08:00, committed by GitHub
parent 5dcb87a88b
commit 65c91cbbbc
13 changed files with 656 additions and 79 deletions

litellm/proxy/_types.py

@@ -248,6 +248,7 @@ class LiteLLMRoutes(enum.Enum):
"/azure",
"/openai",
"/assemblyai",
"/eu.assemblyai",
]
anthropic_routes = [

litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py

@@ -20,9 +20,13 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
)
from litellm.secret_managers.main import get_secret_str
from .passthrough_endpoint_router import PassthroughEndpointRouter
router = APIRouter()
default_vertex_config = None
passthrough_endpoint_router = PassthroughEndpointRouter()
def create_request_copy(request: Request):
return {
@@ -68,8 +72,9 @@ async def gemini_proxy_route(
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
gemini_api_key: Optional[str] = litellm.utils.get_secret( # type: ignore
secret_name="GEMINI_API_KEY"
gemini_api_key: Optional[str] = passthrough_endpoint_router.get_credentials(
custom_llm_provider="gemini",
region_name=None,
)
if gemini_api_key is None:
raise Exception(
@@ -126,7 +131,10 @@ async def cohere_proxy_route(
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
cohere_api_key = litellm.utils.get_secret(secret_name="COHERE_API_KEY")
cohere_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider="cohere",
region_name=None,
)
## check for streaming
is_streaming_request = False
@@ -175,7 +183,10 @@ async def anthropic_proxy_route(
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
anthropic_api_key = litellm.utils.get_secret(secret_name="ANTHROPIC_API_KEY")
anthropic_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider="anthropic",
region_name=None,
)
## check for streaming
is_streaming_request = False
@@ -297,18 +308,34 @@ def _is_bedrock_agent_runtime_route(endpoint: str) -> bool:
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["AssemblyAI Pass-through", "pass-through"],
)
@router.api_route(
"/eu.assemblyai/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["AssemblyAI EU Pass-through", "pass-through"],
)
async def assemblyai_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.assembly_passthrough_logging_handler import (
AssemblyAIPassthroughLoggingHandler,
)
"""
[Docs](https://api.assemblyai.com)
"""
base_target_url = "https://api.assemblyai.com"
# Set base URL based on the route
assembly_region = AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
url=str(request.url)
)
base_target_url = (
AssemblyAIPassthroughLoggingHandler._get_assembly_base_url_from_region(
region=assembly_region
)
)
encoded_endpoint = httpx.URL(endpoint).path
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
@@ -318,7 +345,10 @@ async def assemblyai_proxy_route(
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
assemblyai_api_key = litellm.utils.get_secret(secret_name="ASSEMBLYAI_API_KEY")
assemblyai_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider="assemblyai",
region_name=assembly_region,
)
## check for streaming
is_streaming_request = False
@@ -366,7 +396,10 @@ async def azure_proxy_route(
"Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure."
)
# Add or update query parameters
azure_api_key = get_secret_str(secret_name="AZURE_API_KEY")
azure_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider="azure",
region_name=None,
)
if azure_api_key is None:
raise Exception(
"Required 'AZURE_API_KEY' in environment to make pass-through calls to Azure."
@@ -400,7 +433,10 @@ async def openai_proxy_route(
"""
base_target_url = "https://api.openai.com"
# Add or update query parameters
openai_api_key = get_secret_str(secret_name="OPENAI_API_KEY")
openai_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider="openai",
region_name=None,
)
if openai_api_key is None:
raise Exception(
"Required 'OPENAI_API_KEY' in environment to make pass-through calls to OpenAI."

litellm/proxy/pass_through_endpoints/llm_provider_handlers/assembly_passthrough_logging_handler.py

@@ -1,12 +1,13 @@
import asyncio
import json
import os
import time
from datetime import datetime
from typing import Optional, TypedDict
from typing import Literal, Optional, TypedDict
from urllib.parse import urlparse
import httpx
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import (
@@ -18,7 +19,7 @@ from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggin
class AssemblyAITranscriptResponse(TypedDict, total=False):
id: str
language_model: str
speech_model: str
acoustic_model: str
language_code: str
status: str
@@ -27,7 +28,8 @@
class AssemblyAIPassthroughLoggingHandler:
def __init__(self):
self.assembly_ai_base_url = "https://api.assemblyai.com/v2"
self.assembly_ai_base_url = "https://api.assemblyai.com"
self.assembly_ai_eu_base_url = "https://eu.assemblyai.com"
"""
The base URL for the AssemblyAI API
"""
@@ -43,8 +45,6 @@ class AssemblyAIPassthroughLoggingHandler:
The maximum number of polling attempts for the AssemblyAI API.
"""
self.assemblyai_api_key = os.environ.get("ASSEMBLYAI_API_KEY")
def assemblyai_passthrough_logging_handler(
self,
httpx_response: httpx.Response,
@@ -90,27 +90,34 @@
"""
from ..pass_through_endpoints import pass_through_endpoint_logging
model = response_body.get("model", "")
verbose_proxy_logger.debug("response body", json.dumps(response_body, indent=4))
model = response_body.get("speech_model", "")
verbose_proxy_logger.debug(
"response body %s", json.dumps(response_body, indent=4)
)
kwargs["model"] = model
kwargs["custom_llm_provider"] = "assemblyai"
logging_obj.model_call_details["model"] = model
logging_obj.model_call_details["custom_llm_provider"] = "assemblyai"
transcript_id = response_body.get("id")
if transcript_id is None:
raise ValueError(
"Transcript ID is required to log the cost of the transcription"
)
transcript_response = self._poll_assembly_for_transcript_response(transcript_id)
transcript_response = self._poll_assembly_for_transcript_response(
transcript_id=transcript_id, url_route=url_route
)
verbose_proxy_logger.debug(
"finished polling assembly for transcript response- got transcript response",
"finished polling assembly for transcript response- got transcript response %s",
json.dumps(transcript_response, indent=4),
)
if transcript_response:
cost = self.get_cost_for_assembly_transcript(transcript_response)
cost = self.get_cost_for_assembly_transcript(
speech_model=model,
transcript_response=transcript_response,
)
kwargs["response_cost"] = cost
logging_obj.model_call_details["model"] = logging_obj.model
# Make standard logging object for Vertex AI
standard_logging_object = get_standard_logging_object_payload(
kwargs=kwargs,
@@ -157,7 +164,11 @@ class AssemblyAIPassthroughLoggingHandler:
return {}
return dict(transcript_response)
def _get_assembly_transcript(self, transcript_id: str) -> Optional[dict]:
def _get_assembly_transcript(
self,
transcript_id: str,
request_region: Optional[Literal["eu"]] = None,
) -> Optional[dict]:
"""
Get the transcript details from AssemblyAI API
@@ -167,10 +178,25 @@
Returns:
Optional[dict]: Transcript details if successful, None otherwise
"""
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
passthrough_endpoint_router,
)
_base_url = (
self.assembly_ai_eu_base_url
if request_region == "eu"
else self.assembly_ai_base_url
)
_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider="assemblyai",
region_name=request_region,
)
if _api_key is None:
raise ValueError("AssemblyAI API key not found")
try:
url = f"{self.assembly_ai_base_url}/transcript/{transcript_id}"
url = f"{_base_url}/v2/transcript/{transcript_id}"
headers = {
"Authorization": f"Bearer {self.assemblyai_api_key}",
"Authorization": f"Bearer {_api_key}",
"Content-Type": "application/json",
}
@@ -179,11 +205,15 @@
return response.json()
except Exception as e:
verbose_proxy_logger.debug(f"Error getting AssemblyAI transcript: {str(e)}")
verbose_proxy_logger.exception(
f"[Non blocking logging error] Error getting AssemblyAI transcript: {str(e)}"
)
return None
def _poll_assembly_for_transcript_response(
self, transcript_id: str
self,
transcript_id: str,
url_route: Optional[str] = None,
) -> Optional[AssemblyAITranscriptResponse]:
"""
Poll the status of the transcript until it is completed or timeout (30 minutes)
@@ -191,7 +221,12 @@
for _ in range(
self.max_polling_attempts
): # 180 attempts * 10s = 30 minutes max
transcript = self._get_assembly_transcript(transcript_id)
transcript = self._get_assembly_transcript(
request_region=AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
url=url_route
),
transcript_id=transcript_id,
)
if transcript is None:
return None
if (
@@ -205,6 +240,7 @@
@staticmethod
def get_cost_for_assembly_transcript(
transcript_response: AssemblyAITranscriptResponse,
speech_model: str,
) -> Optional[float]:
"""
Get the cost for the assembly transcript
@@ -212,7 +248,50 @@
_audio_duration = transcript_response.get("audio_duration")
if _audio_duration is None:
return None
return _audio_duration * 0.0001
_cost_per_second = (
AssemblyAIPassthroughLoggingHandler.get_cost_per_second_for_assembly_model(
speech_model=speech_model
)
)
if _cost_per_second is None:
return None
return _audio_duration * _cost_per_second
@staticmethod
def get_cost_per_second_for_assembly_model(speech_model: str) -> Optional[float]:
"""
Get the cost per second for the assembly model.
Falls back to assemblyai/nano if the specific speech model info cannot be found.
"""
try:
# First try with the provided speech model
try:
model_info = litellm.get_model_info(
model=speech_model,
custom_llm_provider="assemblyai",
)
if model_info and model_info.get("input_cost_per_second") is not None:
return model_info.get("input_cost_per_second")
except Exception:
pass # Continue to fallback if model not found
# Fallback to assemblyai/nano if speech model info not found
try:
model_info = litellm.get_model_info(
model="assemblyai/nano",
custom_llm_provider="assemblyai",
)
if model_info and model_info.get("input_cost_per_second") is not None:
return model_info.get("input_cost_per_second")
except Exception:
pass
return None
except Exception as e:
verbose_proxy_logger.exception(
f"[Non blocking logging error] Error getting AssemblyAI model info: {str(e)}"
)
return None
@staticmethod
def _should_log_request(request_method: str) -> bool:
@@ -220,3 +299,25 @@
only POST transcription jobs are logged. litellm will POLL assembly to wait for the transcription to complete to log the complete response / cost
"""
return request_method == "POST"
@staticmethod
def _get_assembly_region_from_url(url: Optional[str]) -> Optional[Literal["eu"]]:
"""
Get the region from the URL
"""
if url is None:
return None
if urlparse(url).hostname == "eu.assemblyai.com":
return "eu"
return None
@staticmethod
def _get_assembly_base_url_from_region(region: Optional[Literal["eu"]]) -> str:
"""
Get the base URL for the AssemblyAI API
if region == "eu", return "https://api.eu.assemblyai.com"
else return "https://api.assemblyai.com"
"""
if region == "eu":
return "https://api.eu.assemblyai.com"
return "https://api.assemblyai.com"

litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py

@@ -0,0 +1,93 @@
from typing import Dict, Optional
from litellm._logging import verbose_logger
from litellm.secret_managers.main import get_secret_str
class PassthroughEndpointRouter:
"""
Use this class to Set/Get credentials for pass-through endpoints
"""
def __init__(self):
self.credentials: Dict[str, str] = {}
def set_pass_through_credentials(
self,
custom_llm_provider: str,
api_base: Optional[str],
api_key: Optional[str],
):
"""
Set credentials for a pass-through endpoint. Used when a user adds a pass-through LLM endpoint on the UI.
Args:
custom_llm_provider: The provider of the pass-through endpoint
api_base: The base URL of the pass-through endpoint
api_key: The API key for the pass-through endpoint
"""
credential_name = self._get_credential_name_for_provider(
custom_llm_provider=custom_llm_provider,
region_name=self._get_region_name_from_api_base(
api_base=api_base, custom_llm_provider=custom_llm_provider
),
)
if api_key is None:
raise ValueError("api_key is required for setting pass-through credentials")
self.credentials[credential_name] = api_key
def get_credentials(
self,
custom_llm_provider: str,
region_name: Optional[str],
) -> Optional[str]:
credential_name = self._get_credential_name_for_provider(
custom_llm_provider=custom_llm_provider,
region_name=region_name,
)
verbose_logger.debug(
f"Pass-through llm endpoints router, looking for credentials for {credential_name}"
)
if credential_name in self.credentials:
verbose_logger.debug(f"Found credentials for {credential_name}")
return self.credentials[credential_name]
else:
verbose_logger.debug(
f"No credentials found for {credential_name}, looking for env variable"
)
_env_variable_name = (
self._get_default_env_variable_name_passthrough_endpoint(
custom_llm_provider=custom_llm_provider,
)
)
return get_secret_str(_env_variable_name)
def _get_credential_name_for_provider(
self,
custom_llm_provider: str,
region_name: Optional[str],
) -> str:
if region_name is None:
return f"{custom_llm_provider.upper()}_API_KEY"
return f"{custom_llm_provider.upper()}_{region_name.upper()}_API_KEY"
def _get_region_name_from_api_base(
self,
custom_llm_provider: str,
api_base: Optional[str],
) -> Optional[str]:
"""
Get the region name from the API base.
Each provider might have a different way of specifying the region in the API base - this is where you can use conditional logic to handle that.
"""
if custom_llm_provider == "assemblyai":
if api_base and "eu" in api_base:
return "eu"
return None
@staticmethod
def _get_default_env_variable_name_passthrough_endpoint(
custom_llm_provider: str,
) -> str:
return f"{custom_llm_provider.upper()}_API_KEY"