(QA+UI) - e2e flow for adding assembly ai passthrough endpoints (#8337)

* add initial test for assembly ai

* start using PassthroughEndpointRouter

* migrate to llm passthrough endpoints

* add assembly ai as a known provider

* fix PassthroughEndpointRouter

* fix set_pass_through_credentials

* working EU request to assembly ai pass through endpoint

* add e2e test assembly

* test_assemblyai_routes_with_bad_api_key

* clean up pass through endpoint router

* e2e testing for assembly ai pass through

* test assembly ai e2e testing

* delete assembly ai models

* fix code quality

* ui working assembly ai api base flow

* fix install assembly ai

* update model call details with kwargs for pass through logging

* fix tracking assembly ai model in response

* _handle_assemblyai_passthrough_logging

* fix test_initialize_deployment_for_pass_through_unsupported_provider

* TestPassthroughEndpointRouter

* _get_assembly_transcript

* fix assembly ai pt logging tests

* fix assemblyai_proxy_route

* fix _get_assembly_region_from_url
Ishaan Jaff, 2025-02-06 18:27:54 -08:00, committed by GitHub
parent 5dcb87a88b
commit 65c91cbbbc
13 changed files with 656 additions and 79 deletions

View file

@@ -1552,6 +1552,7 @@ jobs:
             pip install "pytest-retry==1.6.3"
             pip install "pytest-mock==3.12.0"
             pip install "pytest-asyncio==0.21.1"
+            pip install "assemblyai==0.37.0"
       - run:
           name: Build Docker image
           command: docker build -t my-app:latest -f ./docker/Dockerfile.database .

View file

@@ -248,6 +248,7 @@ class LiteLLMRoutes(enum.Enum):
         "/azure",
         "/openai",
         "/assemblyai",
+        "/eu.assemblyai",
     ]

     anthropic_routes = [

View file

@@ -20,9 +20,13 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
 )
 from litellm.secret_managers.main import get_secret_str

+from .passthrough_endpoint_router import PassthroughEndpointRouter
+
 router = APIRouter()

 default_vertex_config = None
+passthrough_endpoint_router = PassthroughEndpointRouter()


 def create_request_copy(request: Request):
     return {
@@ -68,8 +72,9 @@ async def gemini_proxy_route(
     updated_url = base_url.copy_with(path=encoded_endpoint)

     # Add or update query parameters
-    gemini_api_key: Optional[str] = litellm.utils.get_secret(  # type: ignore
-        secret_name="GEMINI_API_KEY"
+    gemini_api_key: Optional[str] = passthrough_endpoint_router.get_credentials(
+        custom_llm_provider="gemini",
+        region_name=None,
     )
     if gemini_api_key is None:
         raise Exception(
@@ -126,7 +131,10 @@ async def cohere_proxy_route(
     updated_url = base_url.copy_with(path=encoded_endpoint)

     # Add or update query parameters
-    cohere_api_key = litellm.utils.get_secret(secret_name="COHERE_API_KEY")
+    cohere_api_key = passthrough_endpoint_router.get_credentials(
+        custom_llm_provider="cohere",
+        region_name=None,
+    )

     ## check for streaming
     is_streaming_request = False
@@ -175,7 +183,10 @@ async def anthropic_proxy_route(
     updated_url = base_url.copy_with(path=encoded_endpoint)

     # Add or update query parameters
-    anthropic_api_key = litellm.utils.get_secret(secret_name="ANTHROPIC_API_KEY")
+    anthropic_api_key = passthrough_endpoint_router.get_credentials(
+        custom_llm_provider="anthropic",
+        region_name=None,
+    )

     ## check for streaming
     is_streaming_request = False
@@ -297,18 +308,34 @@ def _is_bedrock_agent_runtime_route(endpoint: str) -> bool:
     methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
     tags=["AssemblyAI Pass-through", "pass-through"],
 )
+@router.api_route(
+    "/eu.assemblyai/{endpoint:path}",
+    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
+    tags=["AssemblyAI EU Pass-through", "pass-through"],
+)
 async def assemblyai_proxy_route(
     endpoint: str,
     request: Request,
     fastapi_response: Response,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
+    from litellm.proxy.pass_through_endpoints.llm_provider_handlers.assembly_passthrough_logging_handler import (
+        AssemblyAIPassthroughLoggingHandler,
+    )
+
     """
     [Docs](https://api.assemblyai.com)
     """
-    base_target_url = "https://api.assemblyai.com"
+    # Set base URL based on the route
+    assembly_region = AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
+        url=str(request.url)
+    )
+    base_target_url = (
+        AssemblyAIPassthroughLoggingHandler._get_assembly_base_url_from_region(
+            region=assembly_region
+        )
+    )
     encoded_endpoint = httpx.URL(endpoint).path
+
     # Ensure endpoint starts with '/' for proper URL construction
     if not encoded_endpoint.startswith("/"):
         encoded_endpoint = "/" + encoded_endpoint
@@ -318,7 +345,10 @@ async def assemblyai_proxy_route(
     updated_url = base_url.copy_with(path=encoded_endpoint)

     # Add or update query parameters
-    assemblyai_api_key = litellm.utils.get_secret(secret_name="ASSEMBLYAI_API_KEY")
+    assemblyai_api_key = passthrough_endpoint_router.get_credentials(
+        custom_llm_provider="assemblyai",
+        region_name=assembly_region,
+    )

     ## check for streaming
     is_streaming_request = False
@@ -366,7 +396,10 @@ async def azure_proxy_route(
             "Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure."
         )
     # Add or update query parameters
-    azure_api_key = get_secret_str(secret_name="AZURE_API_KEY")
+    azure_api_key = passthrough_endpoint_router.get_credentials(
+        custom_llm_provider="azure",
+        region_name=None,
+    )
     if azure_api_key is None:
         raise Exception(
             "Required 'AZURE_API_KEY' in environment to make pass-through calls to Azure."
@@ -400,7 +433,10 @@ async def openai_proxy_route(
     """
     base_target_url = "https://api.openai.com"
     # Add or update query parameters
-    openai_api_key = get_secret_str(secret_name="OPENAI_API_KEY")
+    openai_api_key = passthrough_endpoint_router.get_credentials(
+        custom_llm_provider="openai",
+        region_name=None,
+    )
     if openai_api_key is None:
         raise Exception(
             "Required 'OPENAI_API_KEY' in environment to make pass-through calls to OpenAI."

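Every provider route above now shares one lookup order: credentials stored via the router (set from the UI/DB) first, then the provider's default environment variable. Below is a minimal illustrative sketch of that resolution order, not the litellm implementation; the function name and the in-memory dict are assumptions for the example.

```python
import os
from typing import Dict, Optional


def resolve_passthrough_key(
    stored: Dict[str, str],
    custom_llm_provider: str,
    region_name: Optional[str] = None,
) -> Optional[str]:
    # Stored credentials are keyed "<PROVIDER>_API_KEY", or
    # "<PROVIDER>_<REGION>_API_KEY" when a region applies (e.g. ASSEMBLYAI_EU_API_KEY).
    name = f"{custom_llm_provider.upper()}_API_KEY"
    if region_name is not None:
        name = f"{custom_llm_provider.upper()}_{region_name.upper()}_API_KEY"
    if name in stored:
        return stored[name]
    # Fall back to the provider's default env variable, e.g. COHERE_API_KEY.
    return os.environ.get(f"{custom_llm_provider.upper()}_API_KEY")


print(resolve_passthrough_key({"ASSEMBLYAI_EU_API_KEY": "sk-eu"}, "assemblyai", "eu"))
```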
View file

@@ -1,12 +1,13 @@
 import asyncio
 import json
-import os
 import time
 from datetime import datetime
-from typing import Optional, TypedDict
+from typing import Literal, Optional, TypedDict
+from urllib.parse import urlparse

 import httpx

+import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.litellm_core_utils.litellm_logging import (
@@ -18,7 +19,7 @@ from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggin
 class AssemblyAITranscriptResponse(TypedDict, total=False):
     id: str
-    language_model: str
+    speech_model: str
     acoustic_model: str
     language_code: str
     status: str
@@ -27,7 +28,8 @@ class AssemblyAITranscriptResponse(TypedDict, total=False):
 class AssemblyAIPassthroughLoggingHandler:
     def __init__(self):
-        self.assembly_ai_base_url = "https://api.assemblyai.com/v2"
+        self.assembly_ai_base_url = "https://api.assemblyai.com"
+        self.assembly_ai_eu_base_url = "https://eu.assemblyai.com"
         """
         The base URL for the AssemblyAI API
         """
@@ -43,8 +45,6 @@ class AssemblyAIPassthroughLoggingHandler:
         The maximum number of polling attempts for the AssemblyAI API.
         """

-        self.assemblyai_api_key = os.environ.get("ASSEMBLYAI_API_KEY")
-
     def assemblyai_passthrough_logging_handler(
         self,
         httpx_response: httpx.Response,
@@ -90,27 +90,34 @@ class AssemblyAIPassthroughLoggingHandler:
         """
         from ..pass_through_endpoints import pass_through_endpoint_logging

-        model = response_body.get("model", "")
-        verbose_proxy_logger.debug("response body", json.dumps(response_body, indent=4))
+        model = response_body.get("speech_model", "")
+        verbose_proxy_logger.debug(
+            "response body %s", json.dumps(response_body, indent=4)
+        )
         kwargs["model"] = model
         kwargs["custom_llm_provider"] = "assemblyai"
+        logging_obj.model_call_details["model"] = model
+        logging_obj.model_call_details["custom_llm_provider"] = "assemblyai"
         transcript_id = response_body.get("id")
         if transcript_id is None:
             raise ValueError(
                 "Transcript ID is required to log the cost of the transcription"
             )
-        transcript_response = self._poll_assembly_for_transcript_response(transcript_id)
+        transcript_response = self._poll_assembly_for_transcript_response(
+            transcript_id=transcript_id, url_route=url_route
+        )
         verbose_proxy_logger.debug(
-            "finished polling assembly for transcript response- got transcript response",
+            "finished polling assembly for transcript response- got transcript response %s",
            json.dumps(transcript_response, indent=4),
         )

         if transcript_response:
-            cost = self.get_cost_for_assembly_transcript(transcript_response)
+            cost = self.get_cost_for_assembly_transcript(
+                speech_model=model,
+                transcript_response=transcript_response,
+            )
             kwargs["response_cost"] = cost

-        logging_obj.model_call_details["model"] = logging_obj.model
-
         # Make standard logging object for Vertex AI
         standard_logging_object = get_standard_logging_object_payload(
             kwargs=kwargs,
@@ -157,7 +164,11 @@ class AssemblyAIPassthroughLoggingHandler:
             return {}
         return dict(transcript_response)

-    def _get_assembly_transcript(self, transcript_id: str) -> Optional[dict]:
+    def _get_assembly_transcript(
+        self,
+        transcript_id: str,
+        request_region: Optional[Literal["eu"]] = None,
+    ) -> Optional[dict]:
         """
         Get the transcript details from AssemblyAI API
@@ -167,10 +178,25 @@ class AssemblyAIPassthroughLoggingHandler:
         Returns:
             Optional[dict]: Transcript details if successful, None otherwise
         """
+        from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
+            passthrough_endpoint_router,
+        )
+
+        _base_url = (
+            self.assembly_ai_eu_base_url
+            if request_region == "eu"
+            else self.assembly_ai_base_url
+        )
+        _api_key = passthrough_endpoint_router.get_credentials(
+            custom_llm_provider="assemblyai",
+            region_name=request_region,
+        )
+        if _api_key is None:
+            raise ValueError("AssemblyAI API key not found")
         try:
-            url = f"{self.assembly_ai_base_url}/transcript/{transcript_id}"
+            url = f"{_base_url}/v2/transcript/{transcript_id}"
             headers = {
-                "Authorization": f"Bearer {self.assemblyai_api_key}",
+                "Authorization": f"Bearer {_api_key}",
                 "Content-Type": "application/json",
             }
@@ -179,11 +205,15 @@ class AssemblyAIPassthroughLoggingHandler:
             return response.json()
         except Exception as e:
-            verbose_proxy_logger.debug(f"Error getting AssemblyAI transcript: {str(e)}")
+            verbose_proxy_logger.exception(
+                f"[Non blocking logging error] Error getting AssemblyAI transcript: {str(e)}"
+            )
             return None

     def _poll_assembly_for_transcript_response(
-        self, transcript_id: str
+        self,
+        transcript_id: str,
+        url_route: Optional[str] = None,
     ) -> Optional[AssemblyAITranscriptResponse]:
         """
         Poll the status of the transcript until it is completed or timeout (30 minutes)
@@ -191,7 +221,12 @@ class AssemblyAIPassthroughLoggingHandler:
         for _ in range(
             self.max_polling_attempts
         ):  # 180 attempts * 10s = 30 minutes max
-            transcript = self._get_assembly_transcript(transcript_id)
+            transcript = self._get_assembly_transcript(
+                request_region=AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
+                    url=url_route
+                ),
+                transcript_id=transcript_id,
+            )
             if transcript is None:
                 return None
             if (
@@ -205,6 +240,7 @@ class AssemblyAIPassthroughLoggingHandler:
     @staticmethod
     def get_cost_for_assembly_transcript(
         transcript_response: AssemblyAITranscriptResponse,
+        speech_model: str,
     ) -> Optional[float]:
         """
         Get the cost for the assembly transcript
@@ -212,7 +248,50 @@ class AssemblyAIPassthroughLoggingHandler:
         _audio_duration = transcript_response.get("audio_duration")
         if _audio_duration is None:
             return None
-        return _audio_duration * 0.0001
+        _cost_per_second = (
+            AssemblyAIPassthroughLoggingHandler.get_cost_per_second_for_assembly_model(
+                speech_model=speech_model
+            )
+        )
+        if _cost_per_second is None:
+            return None
+        return _audio_duration * _cost_per_second
+
+    @staticmethod
+    def get_cost_per_second_for_assembly_model(speech_model: str) -> Optional[float]:
+        """
+        Get the cost per second for the assembly model.
+        Falls back to assemblyai/nano if the specific speech model info cannot be found.
+        """
+        try:
+            # First try with the provided speech model
+            try:
+                model_info = litellm.get_model_info(
+                    model=speech_model,
+                    custom_llm_provider="assemblyai",
+                )
+                if model_info and model_info.get("input_cost_per_second") is not None:
+                    return model_info.get("input_cost_per_second")
+            except Exception:
+                pass  # Continue to fallback if model not found
+
+            # Fallback to assemblyai/nano if speech model info not found
+            try:
+                model_info = litellm.get_model_info(
+                    model="assemblyai/nano",
+                    custom_llm_provider="assemblyai",
+                )
+                if model_info and model_info.get("input_cost_per_second") is not None:
+                    return model_info.get("input_cost_per_second")
+            except Exception:
+                pass
+
+            return None
+        except Exception as e:
+            verbose_proxy_logger.exception(
+                f"[Non blocking logging error] Error getting AssemblyAI model info: {str(e)}"
+            )
+            return None

     @staticmethod
     def _should_log_request(request_method: str) -> bool:
@@ -220,3 +299,25 @@ class AssemblyAIPassthroughLoggingHandler:
         only POST transcription jobs are logged. litellm will POLL assembly to wait for the transcription to complete to log the complete response / cost
         """
         return request_method == "POST"
+
+    @staticmethod
+    def _get_assembly_region_from_url(url: Optional[str]) -> Optional[Literal["eu"]]:
+        """
+        Get the region from the URL
+        """
+        if url is None:
+            return None
+        if urlparse(url).hostname == "eu.assemblyai.com":
+            return "eu"
+        return None
+
+    @staticmethod
+    def _get_assembly_base_url_from_region(region: Optional[Literal["eu"]]) -> str:
+        """
+        Get the base URL for the AssemblyAI API
+        if region == "eu", return "https://api.eu.assemblyai.com"
+        else return "https://api.assemblyai.com"
+        """
+        if region == "eu":
+            return "https://api.eu.assemblyai.com"
+        return "https://api.assemblyai.com"
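The two static helpers at the bottom of this file drive both the route selection and the transcript polling. A quick sanity check of their behavior, assuming a litellm install with this patch applied:

```python
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.assembly_passthrough_logging_handler import (
    AssemblyAIPassthroughLoggingHandler as H,
)

# Only the eu.assemblyai.com hostname maps to the "eu" region; anything else is None.
assert H._get_assembly_region_from_url(url="https://eu.assemblyai.com/v2/transcript") == "eu"
assert H._get_assembly_region_from_url(url="https://api.assemblyai.com/v2/transcript") is None

# The region then picks the upstream base URL the proxy forwards to.
assert H._get_assembly_base_url_from_region(region="eu") == "https://api.eu.assemblyai.com"
assert H._get_assembly_base_url_from_region(region=None) == "https://api.assemblyai.com"
```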

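On the cost change above: instead of the old hardcoded 0.0001 $/second, the handler now multiplies the transcript's `audio_duration` by the speech model's `input_cost_per_second` from `litellm.get_model_info()`, falling back to `assemblyai/nano`. A worked example, using a hypothetical per-second price:

```python
# Hypothetical numbers for illustration; in the handler, cost_per_second comes from
# litellm.get_model_info(model=speech_model, custom_llm_provider="assemblyai").
audio_duration = 280            # seconds, from the transcript response
cost_per_second = 0.0001        # hypothetical input_cost_per_second for the model
response_cost = audio_duration * cost_per_second
print(f"${response_cost:.4f}")  # -> $0.0280
```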
View file

@@ -0,0 +1,93 @@
+from typing import Dict, Optional
+
+from litellm._logging import verbose_logger
+from litellm.secret_managers.main import get_secret_str
+
+
+class PassthroughEndpointRouter:
+    """
+    Use this class to Set/Get credentials for pass-through endpoints
+    """
+
+    def __init__(self):
+        self.credentials: Dict[str, str] = {}
+
+    def set_pass_through_credentials(
+        self,
+        custom_llm_provider: str,
+        api_base: Optional[str],
+        api_key: Optional[str],
+    ):
+        """
+        Set credentials for a pass-through endpoint. Used when a user adds a pass-through LLM endpoint on the UI.
+
+        Args:
+            custom_llm_provider: The provider of the pass-through endpoint
+            api_base: The base URL of the pass-through endpoint
+            api_key: The API key for the pass-through endpoint
+        """
+        credential_name = self._get_credential_name_for_provider(
+            custom_llm_provider=custom_llm_provider,
+            region_name=self._get_region_name_from_api_base(
+                api_base=api_base, custom_llm_provider=custom_llm_provider
+            ),
+        )
+        if api_key is None:
+            raise ValueError("api_key is required for setting pass-through credentials")
+        self.credentials[credential_name] = api_key
+
+    def get_credentials(
+        self,
+        custom_llm_provider: str,
+        region_name: Optional[str],
+    ) -> Optional[str]:
+        credential_name = self._get_credential_name_for_provider(
+            custom_llm_provider=custom_llm_provider,
+            region_name=region_name,
+        )
+        verbose_logger.debug(
+            f"Pass-through llm endpoints router, looking for credentials for {credential_name}"
+        )
+        if credential_name in self.credentials:
+            verbose_logger.debug(f"Found credentials for {credential_name}")
+            return self.credentials[credential_name]
+        else:
+            verbose_logger.debug(
+                f"No credentials found for {credential_name}, looking for env variable"
+            )
+            _env_variable_name = (
+                self._get_default_env_variable_name_passthrough_endpoint(
+                    custom_llm_provider=custom_llm_provider,
+                )
+            )
+            return get_secret_str(_env_variable_name)
+
+    def _get_credential_name_for_provider(
+        self,
+        custom_llm_provider: str,
+        region_name: Optional[str],
+    ) -> str:
+        if region_name is None:
+            return f"{custom_llm_provider.upper()}_API_KEY"
+        return f"{custom_llm_provider.upper()}_{region_name.upper()}_API_KEY"
+
+    def _get_region_name_from_api_base(
+        self,
+        custom_llm_provider: str,
+        api_base: Optional[str],
+    ) -> Optional[str]:
+        """
+        Get the region name from the API base.
+
+        Each provider might have a different way of specifying the region in the API base - this is where you can use conditional logic to handle that.
+        """
+        if custom_llm_provider == "assemblyai":
+            if api_base and "eu" in api_base:
+                return "eu"
+        return None
+
+    @staticmethod
+    def _get_default_env_variable_name_passthrough_endpoint(
+        custom_llm_provider: str,
+    ) -> str:
+        return f"{custom_llm_provider.upper()}_API_KEY"

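Taken together, the router behaves like this. A short usage sketch mirroring the unit tests further down (assumes a litellm install with this patch; the example key values are made up):

```python
import os

from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
    PassthroughEndpointRouter,
)

router = PassthroughEndpointRouter()

# An API base containing "eu" stores the key under ASSEMBLYAI_EU_API_KEY ...
router.set_pass_through_credentials(
    custom_llm_provider="assemblyai",
    api_base="https://api.eu.assemblyai.com",
    api_key="assembly-eu-key",
)
assert router.get_credentials("assemblyai", region_name="eu") == "assembly-eu-key"

# ... while a provider with no stored key falls back to its default env variable.
os.environ["COHERE_API_KEY"] = "cohere-env-key"
assert router.get_credentials("cohere", region_name=None) == "cohere-env-key"
```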
View file

@@ -4180,8 +4180,14 @@ class Router:
                 vertex_credentials=deployment.litellm_params.vertex_credentials,
             )
         else:
-            verbose_router_logger.error(
-                f"Unsupported provider - {custom_llm_provider} for pass-through endpoints"
+            from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
+                passthrough_endpoint_router,
+            )
+
+            passthrough_endpoint_router.set_pass_through_credentials(
+                custom_llm_provider=custom_llm_provider,
+                api_base=deployment.litellm_params.api_base,
+                api_key=deployment.litellm_params.api_key,
             )
             pass
         pass

View file

@@ -1870,6 +1870,7 @@ class LlmProviders(str, Enum):
     LANGFUSE = "langfuse"
     HUMANLOOP = "humanloop"
     TOPAZ = "topaz"
+    ASSEMBLYAI = "assemblyai"

     # Create a set of all provider values for quick lookup

View file

@@ -1,5 +1,5 @@
 """
-This test ensures that the proxy can passthrough anthropic requests
+This test ensures that the proxy can passthrough requests to assemblyai
 """

 import pytest

View file

@@ -41,7 +41,6 @@ from litellm.proxy.pass_through_endpoints.success_handler import (
 @pytest.fixture
 def assembly_handler():
     handler = AssemblyAIPassthroughLoggingHandler()
-    handler.assemblyai_api_key = "test-key"
     return handler
@@ -66,21 +65,27 @@ def test_should_log_request():
 def test_get_assembly_transcript(assembly_handler, mock_transcript_response):
     """
     Test that the _get_assembly_transcript method calls GET /v2/transcript/{transcript_id}
+    and uses the test key returned by the mocked get_credentials.
     """
-    with patch("httpx.get") as mock_get:
-        mock_get.return_value.json.return_value = mock_transcript_response
-        mock_get.return_value.raise_for_status.return_value = None
-
-        transcript = assembly_handler._get_assembly_transcript("test-transcript-id")
-        assert transcript == mock_transcript_response
-
-        mock_get.assert_called_once_with(
-            "https://api.assemblyai.com/v2/transcript/test-transcript-id",
-            headers={
-                "Authorization": "Bearer test-key",
-                "Content-Type": "application/json",
-            },
-        )
+    # Patch get_credentials to return "test-key"
+    with patch(
+        "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router.get_credentials",
+        return_value="test-key",
+    ):
+        with patch("httpx.get") as mock_get:
+            mock_get.return_value.json.return_value = mock_transcript_response
+            mock_get.return_value.raise_for_status.return_value = None
+
+            transcript = assembly_handler._get_assembly_transcript("test-transcript-id")
+            assert transcript == mock_transcript_response
+
+            mock_get.assert_called_once_with(
+                "https://api.assemblyai.com/v2/transcript/test-transcript-id",
+                headers={
+                    "Authorization": "Bearer test-key",
+                    "Content-Type": "application/json",
+                },
+            )
@@ -89,18 +94,24 @@ def test_poll_assembly_for_transcript_response(
     """
     Test that the _poll_assembly_for_transcript_response method returns the correct transcript response
     """
-    with patch("httpx.get") as mock_get:
-        mock_get.return_value.json.return_value = mock_transcript_response
-        mock_get.return_value.raise_for_status.return_value = None
-
-        # Override polling settings for faster test
-        assembly_handler.polling_interval = 0.01
-        assembly_handler.max_polling_attempts = 2
-
-        transcript = assembly_handler._poll_assembly_for_transcript_response(
-            "test-transcript-id"
-        )
-        assert transcript == AssemblyAITranscriptResponse(**mock_transcript_response)
+    with patch(
+        "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router.get_credentials",
+        return_value="test-key",
+    ):
+        with patch("httpx.get") as mock_get:
+            mock_get.return_value.json.return_value = mock_transcript_response
+            mock_get.return_value.raise_for_status.return_value = None
+
+            # Override polling settings for faster test
+            assembly_handler.polling_interval = 0.01
+            assembly_handler.max_polling_attempts = 2
+
+            transcript = assembly_handler._poll_assembly_for_transcript_response(
+                "test-transcript-id",
+            )
+            assert transcript == AssemblyAITranscriptResponse(
+                **mock_transcript_response
+            )


 def test_is_assemblyai_route():

View file

@@ -0,0 +1,134 @@
+import json
+import os
+import sys
+from datetime import datetime
+from unittest.mock import AsyncMock, Mock, patch, MagicMock
+
+sys.path.insert(0, os.path.abspath("../.."))  #
+
+import unittest
+from unittest.mock import patch
+
+from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
+    PassthroughEndpointRouter,
+)
+
+passthrough_endpoint_router = PassthroughEndpointRouter()
+
+"""
+1. Basic Usage
+    - Set OpenAI, AssemblyAI, Anthropic, Cohere credentials
+    - GET credentials from passthrough_endpoint_router
+
+2. Basic Usage - when not using DB
+    - No credentials set
+    - call GET credentials with provider name, assert that it reads the secret from the environment variable
+
+3. Unit test for _get_default_env_variable_name_passthrough_endpoint
+"""
+
+
+class TestPassthroughEndpointRouter(unittest.TestCase):
+    def setUp(self):
+        self.router = PassthroughEndpointRouter()
+
+    def test_set_and_get_credentials(self):
+        """
+        1. Basic Usage:
+            - Set credentials for OpenAI, AssemblyAI, Anthropic, Cohere
+            - GET credentials from passthrough_endpoint_router (from the memory store when available)
+        """
+        # OpenAI: standard (no region-specific logic)
+        self.router.set_pass_through_credentials("openai", None, "openai_key")
+        self.assertEqual(self.router.get_credentials("openai", None), "openai_key")
+
+        # AssemblyAI: using an API base that contains 'eu' should trigger regional logic.
+        api_base_eu = "https://api.eu.assemblyai.com"
+        self.router.set_pass_through_credentials(
+            "assemblyai", api_base_eu, "assemblyai_key"
+        )
+        # When calling get_credentials, pass the region "eu" (extracted from the API base)
+        self.assertEqual(
+            self.router.get_credentials("assemblyai", "eu"), "assemblyai_key"
+        )
+
+        # Anthropic: no region set
+        self.router.set_pass_through_credentials("anthropic", None, "anthropic_key")
+        self.assertEqual(
+            self.router.get_credentials("anthropic", None), "anthropic_key"
+        )
+
+        # Cohere: no region set
+        self.router.set_pass_through_credentials("cohere", None, "cohere_key")
+        self.assertEqual(self.router.get_credentials("cohere", None), "cohere_key")
+
+    def test_get_credentials_from_env(self):
+        """
+        2. Basic Usage - when not using the database:
+            - No credentials set in memory
+            - Call get_credentials with provider name and expect it to read from the environment variable (via get_secret_str)
+        """
+        # Patch the get_secret_str function within the router's module.
+        with patch(
+            "litellm.proxy.pass_through_endpoints.passthrough_endpoint_router.get_secret_str"
+        ) as mock_get_secret:
+            mock_get_secret.return_value = "env_openai_key"
+            # For "openai", if credentials are not set, it should fallback to the env variable.
+            result = self.router.get_credentials("openai", None)
+            self.assertEqual(result, "env_openai_key")
+            mock_get_secret.assert_called_once_with("OPENAI_API_KEY")
+
+        with patch(
+            "litellm.proxy.pass_through_endpoints.passthrough_endpoint_router.get_secret_str"
+        ) as mock_get_secret:
+            mock_get_secret.return_value = "env_cohere_key"
+            result = self.router.get_credentials("cohere", None)
+            self.assertEqual(result, "env_cohere_key")
+            mock_get_secret.assert_called_once_with("COHERE_API_KEY")
+
+        with patch(
+            "litellm.proxy.pass_through_endpoints.passthrough_endpoint_router.get_secret_str"
+        ) as mock_get_secret:
+            mock_get_secret.return_value = "env_anthropic_key"
+            result = self.router.get_credentials("anthropic", None)
+            self.assertEqual(result, "env_anthropic_key")
+            mock_get_secret.assert_called_once_with("ANTHROPIC_API_KEY")
+
+        with patch(
+            "litellm.proxy.pass_through_endpoints.passthrough_endpoint_router.get_secret_str"
+        ) as mock_get_secret:
+            mock_get_secret.return_value = "env_azure_key"
+            result = self.router.get_credentials("azure", None)
+            self.assertEqual(result, "env_azure_key")
+            mock_get_secret.assert_called_once_with("AZURE_API_KEY")
+
+    def test_default_env_variable_method(self):
+        """
+        3. Unit test for _get_default_env_variable_name_passthrough_endpoint:
+            - Should return the provider in uppercase followed by _API_KEY.
+        """
+        self.assertEqual(
+            PassthroughEndpointRouter._get_default_env_variable_name_passthrough_endpoint(
+                "openai"
+            ),
+            "OPENAI_API_KEY",
+        )
+        self.assertEqual(
+            PassthroughEndpointRouter._get_default_env_variable_name_passthrough_endpoint(
+                "assemblyai"
+            ),
+            "ASSEMBLYAI_API_KEY",
+        )
+        self.assertEqual(
+            PassthroughEndpointRouter._get_default_env_variable_name_passthrough_endpoint(
+                "anthropic"
+            ),
+            "ANTHROPIC_API_KEY",
+        )
+        self.assertEqual(
+            PassthroughEndpointRouter._get_default_env_variable_name_passthrough_endpoint(
+                "cohere"
+            ),
+            "COHERE_API_KEY",
+        )

View file

@@ -76,27 +76,6 @@ def test_initialize_deployment_for_pass_through_missing_params():
     )


-def test_initialize_deployment_for_pass_through_unsupported_provider():
-    """
-    Test initialization with an unsupported provider
-    """
-    router = Router(model_list=[])
-    deployment = Deployment(
-        model_name="unsupported-test",
-        litellm_params=LiteLLM_Params(
-            model="unsupported/test-model",
-            use_in_pass_through=True,
-        ),
-    )
-
-    # Should not raise an error, but log a warning
-    router._initialize_deployment_for_pass_through(
-        deployment=deployment,
-        custom_llm_provider="unsupported_provider",
-        model="unsupported/test-model",
-    )
-
-
 def test_initialize_deployment_when_pass_through_disabled():
     """
     Test that initialization simply exits when use_in_pass_through is False

View file

@@ -0,0 +1,202 @@
+"""
+Test adding a pass through assemblyai model + api key + api base to the db
+
+wait 20 seconds
+
+make request
+
+Cases to cover
+1. user points api base to <proxy-base>/assemblyai
+2. user points api base to <proxy-base>/assemblyai/us
+3. user points api base to <proxy-base>/assemblyai/eu
+4. Bad API Key / credential - 401
+"""
+
+import time
+import assemblyai as aai
+import pytest
+import httpx
+import os
+import json
+
+TEST_MASTER_KEY = "sk-1234"
+PROXY_BASE_URL = "http://0.0.0.0:4000"
+US_BASE_URL = f"{PROXY_BASE_URL}/assemblyai"
+EU_BASE_URL = f"{PROXY_BASE_URL}/eu.assemblyai"
+ASSEMBLYAI_API_KEY_ENV_VAR = "TEST_SPECIAL_ASSEMBLYAI_API_KEY"
+
+
+def _delete_all_assemblyai_models_from_db():
+    """
+    Delete all assemblyai models from the db
+    """
+    print("Deleting all assemblyai models from the db.......")
+    model_list_response = httpx.get(
+        url=f"{PROXY_BASE_URL}/v2/model/info",
+        headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
+    )
+    response_data = model_list_response.json()
+    print("model list response", json.dumps(response_data, indent=4, default=str))
+
+    # Filter for only AssemblyAI models
+    assemblyai_models = [
+        model
+        for model in response_data["data"]
+        if model.get("litellm_params", {}).get("custom_llm_provider") == "assemblyai"
+    ]
+
+    for model in assemblyai_models:
+        model_id = model["model_info"]["id"]
+        httpx.post(
+            url=f"{PROXY_BASE_URL}/model/delete",
+            headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
+            json={"id": model_id},
+        )
+    print("Deleted all assemblyai models from the db")
+
+
+@pytest.fixture(autouse=True)
+def cleanup_assemblyai_models():
+    """
+    Fixture to clean up AssemblyAI models before and after each test
+    """
+    # Clean up before test
+    _delete_all_assemblyai_models_from_db()
+    # Run the test
+    yield
+    # Clean up after test
+    _delete_all_assemblyai_models_from_db()
+
+
+def test_e2e_assemblyai_passthrough():
+    """
+    Test adding a pass through assemblyai model + api key + api base to the db
+    wait 20 seconds
+    make request
+    """
+    add_assembly_ai_model_to_db(api_base="https://api.assemblyai.com")
+    virtual_key = create_virtual_key()
+    # make request
+    make_assemblyai_basic_transcribe_request(
+        virtual_key=virtual_key, assemblyai_base_url=US_BASE_URL
+    )
+    pass
+
+
+def test_e2e_assemblyai_passthrough_eu():
+    """
+    Test adding a pass through assemblyai model + api key + api base to the db
+    wait 20 seconds
+    make request
+    """
+    add_assembly_ai_model_to_db(api_base="https://api.eu.assemblyai.com")
+    virtual_key = create_virtual_key()
+    # make request
+    make_assemblyai_basic_transcribe_request(
+        virtual_key=virtual_key, assemblyai_base_url=EU_BASE_URL
+    )
+    pass
+
+
+def test_assemblyai_routes_with_bad_api_key():
+    """
+    Test AssemblyAI endpoints with invalid API key to ensure proper error handling
+    """
+    bad_api_key = "sk-12222"
+    payload = {
+        "audio_url": "https://assembly.ai/wildfires.mp3",
+        "audio_end_at": 280,
+        "audio_start_from": 10,
+        "auto_chapters": True,
+    }
+    headers = {
+        "Authorization": f"Bearer {bad_api_key}",
+        "Content-Type": "application/json",
+    }
+
+    # Test EU endpoint
+    eu_response = httpx.post(
+        f"{PROXY_BASE_URL}/eu.assemblyai/v2/transcript", headers=headers, json=payload
+    )
+    assert (
+        eu_response.status_code == 401
+    ), f"Expected 401 unauthorized, got {eu_response.status_code}"
+
+    # Test US endpoint
+    us_response = httpx.post(
+        f"{PROXY_BASE_URL}/assemblyai/v2/transcript", headers=headers, json=payload
+    )
+    assert (
+        us_response.status_code == 401
+    ), f"Expected 401 unauthorized, got {us_response.status_code}"
+
+
+def create_virtual_key():
+    """
+    Create a virtual key
+    """
+    response = httpx.post(
+        url=f"{PROXY_BASE_URL}/key/generate",
+        headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
+        json={},
+    )
+    print(response.json())
+    return response.json()["token"]
+
+
+def add_assembly_ai_model_to_db(
+    api_base: str,
+):
+    """
+    Add the assemblyai model to the db - makes a http request to the /model/new endpoint on PROXY_BASE_URL
+    """
+    print("assembly ai api key", os.getenv(ASSEMBLYAI_API_KEY_ENV_VAR))
+    response = httpx.post(
+        url=f"{PROXY_BASE_URL}/model/new",
+        headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
+        json={
+            "model_name": "assemblyai/*",
+            "litellm_params": {
+                "model": "assemblyai/*",
+                "custom_llm_provider": "assemblyai",
+                "api_key": os.getenv(ASSEMBLYAI_API_KEY_ENV_VAR),
+                "api_base": api_base,
+                "use_in_pass_through": True,
+            },
+            "model_info": {},
+        },
+    )
+    print(response.json())
+    pass
+
+
+def make_assemblyai_basic_transcribe_request(
+    virtual_key: str, assemblyai_base_url: str
+):
+    print("making basic transcribe request to assemblyai passthrough")
+
+    # Replace with your API key
+    aai.settings.api_key = f"Bearer {virtual_key}"
+    aai.settings.base_url = assemblyai_base_url
+
+    # URL of the file to transcribe
+    FILE_URL = "https://assembly.ai/wildfires.mp3"
+
+    # You can also transcribe a local file by passing in a file path
+    # FILE_URL = './path/to/file.mp3'
+
+    transcriber = aai.Transcriber()
+    transcript = transcriber.transcribe(FILE_URL)
+
+    print(transcript)
+    print(transcript.id)
+    if transcript.id:
+        transcript.delete_by_id(transcript.id)
+    else:
+        pytest.fail("Failed to get transcript id")
+
+    if transcript.status == aai.TranscriptStatus.error:
+        print(transcript.error)
+        pytest.fail(f"Failed to transcribe file error: {transcript.error}")
+    else:
+        print(transcript.text)

View file

@@ -1,5 +1,5 @@
 import React from "react";
-import { Form } from "antd";
+import { Form, Select } from "antd";
 import { TextInput, Text } from "@tremor/react";
 import { Row, Col, Typography, Button as Button2, Upload, UploadProps } from "antd";
 import { UploadOutlined } from "@ant-design/icons";
@@ -72,9 +72,21 @@ const ProviderSpecificFields: React.FC<ProviderSpecificFieldsProps> = ({
         </>
       )}
+      {selectedProviderEnum === Providers.AssemblyAI && (
+        <Form.Item
+          rules={[{ required: true, message: "Required" }]}
+          label="API Base"
+          name="api_base"
+        >
+          <Select placeholder="Select API Base">
+            <Select.Option value="https://api.assemblyai.com">https://api.assemblyai.com</Select.Option>
+            <Select.Option value="https://api.eu.assemblyai.com">https://api.eu.assemblyai.com</Select.Option>
+          </Select>
+        </Form.Item>
+      )}
       {(selectedProviderEnum === Providers.Azure ||
-        selectedProviderEnum === Providers.OpenAI_Compatible ||
-        selectedProviderEnum === Providers.AssemblyAI
+        selectedProviderEnum === Providers.OpenAI_Compatible
       ) && (
         <Form.Item
           rules={[{ required: true, message: "Required" }]}