(Feat) - New pass through add assembly ai passthrough endpoints (#8220)

* add assembly ai pass through request

* fix assembly pass through

* fix test_assemblyai_basic_transcribe

* fix assemblyai auth check

* test_assemblyai_transcribe_with_non_admin_key

* working assembly ai test

* working assembly ai proxy route

* use helper func to pass through logging

* clean up logging assembly ai

* test: update test to handle gemini token counter change

* fix(factory.py): fix bedrock http:// handling

* add unit testing for assembly pt handler

* docs assembly ai pass through endpoint

* fix proxy_pass_through_endpoint_tests

* fix standard_passthrough_logging_object

* fix ASSEMBLYAI_API_KEY

* test test_assemblyai_proxy_route_basic_post

* test_assemblyai_proxy_route_get_transcript

* fix is is_assemblyai_route

* test_is_assemblyai_route

---------

Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
This commit is contained in:
Ishaan Jaff 2025-02-03 21:54:32 -08:00 committed by GitHub
parent 0ec5205711
commit adeb09e091
9 changed files with 731 additions and 16 deletions

View file

@ -1594,6 +1594,7 @@ jobs:
pip install "google-cloud-aiplatform==1.43.0" pip install "google-cloud-aiplatform==1.43.0"
pip install aiohttp pip install aiohttp
pip install "openai==1.54.0 " pip install "openai==1.54.0 "
pip install "assemblyai==0.37.0"
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install "pydantic==2.7.1" pip install "pydantic==2.7.1"
pip install "pytest==7.3.1" pip install "pytest==7.3.1"
@ -1626,6 +1627,7 @@ jobs:
-e OPENAI_API_KEY=$OPENAI_API_KEY \ -e OPENAI_API_KEY=$OPENAI_API_KEY \
-e GEMINI_API_KEY=$GEMINI_API_KEY \ -e GEMINI_API_KEY=$GEMINI_API_KEY \
-e ANTHROPIC_API_KEY=$ANTHROPIC_API_KEY \ -e ANTHROPIC_API_KEY=$ANTHROPIC_API_KEY \
-e ASSEMBLYAI_API_KEY=$ASSEMBLYAI_API_KEY \
-e USE_DDTRACE=True \ -e USE_DDTRACE=True \
-e DD_API_KEY=$DD_API_KEY \ -e DD_API_KEY=$DD_API_KEY \
-e DD_SITE=$DD_SITE \ -e DD_SITE=$DD_SITE \

View file

@ -0,0 +1,55 @@
# Assembly AI
Pass-through endpoints for Assembly AI - call Assembly AI endpoints, in native format (no translation).
| Feature | Supported | Notes |
|-------|-------|-------|
| Cost Tracking | ✅ | works across all integrations |
| Logging | ✅ | works across all integrations |
Supports **ALL** Assembly AI Endpoints
[**See All Assembly AI Endpoints**](https://www.assemblyai.com/docs/api-reference)
## Quick Start
Let's call the Assembly AI [`/v2/transcripts` endpoint](https://www.assemblyai.com/docs/api-reference/transcripts)
1. Add Assembly AI API Key to your environment
```bash
export ASSEMBLYAI_API_KEY=""
```
2. Start LiteLLM Proxy
```bash
litellm
# RUNNING on http://0.0.0.0:4000
```
3. Test it!
Let's call the Assembly AI `/v2/transcripts` endpoint
```python
LITELLM_VIRTUAL_KEY = "sk-1234" # <your-virtual-key>
LITELLM_PROXY_BASE_URL = "http://0.0.0.0:4000/assemblyai" # <your-proxy-base-url>/assemblyai
aai.settings.api_key = f"Bearer {LITELLM_VIRTUAL_KEY}"
aai.settings.base_url = LITELLM_PROXY_BASE_URL
# URL of the file to transcribe
FILE_URL = "https://assembly.ai/wildfires.mp3"
# You can also transcribe a local file by passing in a file path
# FILE_URL = './path/to/file.mp3'
transcriber = aai.Transcriber()
transcript = transcriber.transcribe(FILE_URL)
print(transcript)
print(transcript.id)
```

View file

@ -305,6 +305,7 @@ const sidebars = {
"pass_through/cohere", "pass_through/cohere",
"pass_through/anthropic_completion", "pass_through/anthropic_completion",
"pass_through/bedrock", "pass_through/bedrock",
"pass_through/assembly_ai",
"pass_through/langfuse", "pass_through/langfuse",
"proxy/pass_through", "proxy/pass_through",
], ],

View file

@ -247,6 +247,7 @@ class LiteLLMRoutes(enum.Enum):
"/langfuse", "/langfuse",
"/azure", "/azure",
"/openai", "/openai",
"/assemblyai",
] ]
anthropic_routes = [ anthropic_routes = [

View file

@ -292,6 +292,58 @@ def _is_bedrock_agent_runtime_route(endpoint: str) -> bool:
return False return False
@router.api_route(
"/assemblyai/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["AssemblyAI Pass-through", "pass-through"],
)
async def assemblyai_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
[Docs](https://api.assemblyai.com)
"""
base_target_url = "https://api.assemblyai.com"
encoded_endpoint = httpx.URL(endpoint).path
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
base_url = httpx.URL(base_target_url)
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
assemblyai_api_key = litellm.utils.get_secret(secret_name="ASSEMBLYAI_API_KEY")
## check for streaming
is_streaming_request = False
# assemblyai is streaming when 'stream' = True is in the body
if request.method == "POST":
_request_body = await request.json()
if _request_body.get("stream"):
is_streaming_request = True
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(updated_url),
custom_headers={"Authorization": "{}".format(assemblyai_api_key)},
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request=request,
fastapi_response=fastapi_response,
user_api_key_dict=user_api_key_dict,
stream=is_streaming_request, # type: ignore
)
return received_value
@router.api_route( @router.api_route(
"/azure/{endpoint:path}", "/azure/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"], methods=["GET", "POST", "PUT", "DELETE", "PATCH"],

View file

@ -0,0 +1,222 @@
import asyncio
import json
import os
import time
from datetime import datetime
from typing import Optional, TypedDict
import httpx
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import (
get_standard_logging_object_payload,
)
from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggingPayload
class AssemblyAITranscriptResponse(TypedDict, total=False):
id: str
language_model: str
acoustic_model: str
language_code: str
status: str
audio_duration: float
class AssemblyAIPassthroughLoggingHandler:
def __init__(self):
self.assembly_ai_base_url = "https://api.assemblyai.com/v2"
"""
The base URL for the AssemblyAI API
"""
self.polling_interval: float = 10
"""
The polling interval for the AssemblyAI API.
litellm needs to poll the GET /transcript/{transcript_id} endpoint to get the status of the transcript.
"""
self.max_polling_attempts = 180
"""
The maximum number of polling attempts for the AssemblyAI API.
"""
self.assemblyai_api_key = os.environ.get("ASSEMBLYAI_API_KEY")
def assemblyai_passthrough_logging_handler(
self,
httpx_response: httpx.Response,
response_body: dict,
logging_obj: LiteLLMLoggingObj,
url_route: str,
result: str,
start_time: datetime,
end_time: datetime,
cache_hit: bool,
**kwargs,
):
"""
Since cost tracking requires polling the AssemblyAI API, we need to handle this in a separate thread. Hence the executor.submit.
"""
executor.submit(
self._handle_assemblyai_passthrough_logging,
httpx_response,
response_body,
logging_obj,
url_route,
result,
start_time,
end_time,
cache_hit,
**kwargs,
)
def _handle_assemblyai_passthrough_logging(
self,
httpx_response: httpx.Response,
response_body: dict,
logging_obj: LiteLLMLoggingObj,
url_route: str,
result: str,
start_time: datetime,
end_time: datetime,
cache_hit: bool,
**kwargs,
):
"""
Handles logging for AssemblyAI successful passthrough requests
"""
from ..pass_through_endpoints import pass_through_endpoint_logging
model = response_body.get("model", "")
verbose_proxy_logger.debug("response body", json.dumps(response_body, indent=4))
kwargs["model"] = model
kwargs["custom_llm_provider"] = "assemblyai"
transcript_id = response_body.get("id")
if transcript_id is None:
raise ValueError(
"Transcript ID is required to log the cost of the transcription"
)
transcript_response = self._poll_assembly_for_transcript_response(transcript_id)
verbose_proxy_logger.debug(
"finished polling assembly for transcript response- got transcript response",
json.dumps(transcript_response, indent=4),
)
if transcript_response:
cost = self.get_cost_for_assembly_transcript(transcript_response)
kwargs["response_cost"] = cost
logging_obj.model_call_details["model"] = logging_obj.model
# Make standard logging object for Vertex AI
standard_logging_object = get_standard_logging_object_payload(
kwargs=kwargs,
init_response_obj=transcript_response,
start_time=start_time,
end_time=end_time,
logging_obj=logging_obj,
status="success",
)
passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = ( # type: ignore
kwargs.get("passthrough_logging_payload")
)
verbose_proxy_logger.debug(
"standard_passthrough_logging_object %s",
json.dumps(passthrough_logging_payload, indent=4),
)
# pretty print standard logging object
verbose_proxy_logger.debug(
"standard_logging_object= %s", json.dumps(standard_logging_object, indent=4)
)
asyncio.run(
pass_through_endpoint_logging._handle_logging(
logging_obj=logging_obj,
standard_logging_response_object=self._get_response_to_log(
transcript_response
),
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
)
pass
def _get_response_to_log(
self, transcript_response: Optional[AssemblyAITranscriptResponse]
) -> dict:
if transcript_response is None:
return {}
return dict(transcript_response)
def _get_assembly_transcript(self, transcript_id: str) -> Optional[dict]:
"""
Get the transcript details from AssemblyAI API
Args:
response_body (dict): Response containing the transcript ID
Returns:
Optional[dict]: Transcript details if successful, None otherwise
"""
try:
url = f"{self.assembly_ai_base_url}/transcript/{transcript_id}"
headers = {
"Authorization": f"Bearer {self.assemblyai_api_key}",
"Content-Type": "application/json",
}
response = httpx.get(url, headers=headers)
response.raise_for_status()
return response.json()
except Exception as e:
verbose_proxy_logger.debug(f"Error getting AssemblyAI transcript: {str(e)}")
return None
def _poll_assembly_for_transcript_response(
self, transcript_id: str
) -> Optional[AssemblyAITranscriptResponse]:
"""
Poll the status of the transcript until it is completed or timeout (30 minutes)
"""
for _ in range(
self.max_polling_attempts
): # 180 attempts * 10s = 30 minutes max
transcript = self._get_assembly_transcript(transcript_id)
if transcript is None:
return None
if (
transcript.get("status") == "completed"
or transcript.get("status") == "error"
):
return AssemblyAITranscriptResponse(**transcript)
time.sleep(self.polling_interval)
return None
@staticmethod
def get_cost_for_assembly_transcript(
transcript_response: AssemblyAITranscriptResponse,
) -> Optional[float]:
"""
Get the cost for the assembly transcript
"""
_audio_duration = transcript_response.get("audio_duration")
if _audio_duration is None:
return None
return _audio_duration * 0.0001
@staticmethod
def _should_log_request(request_method: str) -> bool:
"""
only POST transcription jobs are logged. litellm will POLL assembly to wait for the transcription to complete to log the complete response / cost
"""
return request_method == "POST"

View file

@ -1,6 +1,7 @@
import json import json
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional, Union
from urllib.parse import urlparse
import httpx import httpx
@ -12,6 +13,9 @@ from litellm.utils import executor as thread_pool_executor
from .llm_provider_handlers.anthropic_passthrough_logging_handler import ( from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
AnthropicPassthroughLoggingHandler, AnthropicPassthroughLoggingHandler,
) )
from .llm_provider_handlers.assembly_passthrough_logging_handler import (
AssemblyAIPassthroughLoggingHandler,
)
from .llm_provider_handlers.vertex_passthrough_logging_handler import ( from .llm_provider_handlers.vertex_passthrough_logging_handler import (
VertexPassthroughLoggingHandler, VertexPassthroughLoggingHandler,
) )
@ -28,6 +32,48 @@ class PassThroughEndpointLogging:
# Anthropic # Anthropic
self.TRACKED_ANTHROPIC_ROUTES = ["/messages"] self.TRACKED_ANTHROPIC_ROUTES = ["/messages"]
self.assemblyai_passthrough_logging_handler = (
AssemblyAIPassthroughLoggingHandler()
)
async def _handle_logging(
self,
logging_obj: LiteLLMLoggingObj,
standard_logging_response_object: Union[
StandardPassThroughResponseObject,
PassThroughEndpointLoggingResultValues,
dict,
],
result: str,
start_time: datetime,
end_time: datetime,
cache_hit: bool,
**kwargs,
):
"""Helper function to handle both sync and async logging operations"""
# Submit to thread pool for sync logging
thread_pool_executor.submit(
logging_obj.success_handler,
standard_logging_response_object,
start_time,
end_time,
cache_hit,
**kwargs,
)
# Handle async logging
await logging_obj.async_success_handler(
result=(
json.dumps(result)
if isinstance(result, dict)
else standard_logging_response_object
),
start_time=start_time,
end_time=end_time,
cache_hit=False,
**kwargs,
)
async def pass_through_async_success_handler( async def pass_through_async_success_handler(
self, self,
httpx_response: httpx.Response, httpx_response: httpx.Response,
@ -79,28 +125,39 @@ class PassThroughEndpointLogging:
anthropic_passthrough_logging_handler_result["result"] anthropic_passthrough_logging_handler_result["result"]
) )
kwargs = anthropic_passthrough_logging_handler_result["kwargs"] kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
elif self.is_assemblyai_route(url_route):
if (
AssemblyAIPassthroughLoggingHandler._should_log_request(
httpx_response.request.method
)
is not True
):
return
self.assemblyai_passthrough_logging_handler.assemblyai_passthrough_logging_handler(
httpx_response=httpx_response,
response_body=response_body or {},
logging_obj=logging_obj,
url_route=url_route,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
return
if standard_logging_response_object is None: if standard_logging_response_object is None:
standard_logging_response_object = StandardPassThroughResponseObject( standard_logging_response_object = StandardPassThroughResponseObject(
response=httpx_response.text response=httpx_response.text
) )
thread_pool_executor.submit(
logging_obj.success_handler,
standard_logging_response_object, # Positional argument 1
start_time, # Positional argument 2
end_time, # Positional argument 3
cache_hit, # Positional argument 4
**kwargs, # Unpacked keyword arguments
)
await logging_obj.async_success_handler( await self._handle_logging(
result=( logging_obj=logging_obj,
json.dumps(result) standard_logging_response_object=standard_logging_response_object,
if isinstance(result, dict) result=result,
else standard_logging_response_object
),
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
cache_hit=False, cache_hit=cache_hit,
**kwargs, **kwargs,
) )
@ -115,3 +172,11 @@ class PassThroughEndpointLogging:
if route in url_route: if route in url_route:
return True return True
return False return False
def is_assemblyai_route(self, url_route: str):
parsed_url = urlparse(url_route)
if parsed_url.hostname == "api.assemblyai.com":
return True
elif "/transcript" in parsed_url.path:
return True
return False

View file

@ -0,0 +1,95 @@
"""
This test ensures that the proxy can passthrough anthropic requests
"""
import pytest
import assemblyai as aai
import aiohttp
import asyncio
import time
TEST_MASTER_KEY = "sk-1234"
TEST_BASE_URL = "http://0.0.0.0:4000/assemblyai"
def test_assemblyai_basic_transcribe():
print("making basic transcribe request to assemblyai passthrough")
# Replace with your API key
aai.settings.api_key = f"Bearer {TEST_MASTER_KEY}"
aai.settings.base_url = TEST_BASE_URL
# URL of the file to transcribe
FILE_URL = "https://assembly.ai/wildfires.mp3"
# You can also transcribe a local file by passing in a file path
# FILE_URL = './path/to/file.mp3'
transcriber = aai.Transcriber()
transcript = transcriber.transcribe(FILE_URL)
print(transcript)
print(transcript.id)
if transcript.id:
transcript.delete_by_id(transcript.id)
else:
pytest.fail("Failed to get transcript id")
if transcript.status == aai.TranscriptStatus.error:
print(transcript.error)
pytest.fail(f"Failed to transcribe file error: {transcript.error}")
else:
print(transcript.text)
async def generate_key(calling_key: str) -> str:
"""Helper function to generate a new API key"""
url = "http://0.0.0.0:4000/key/generate"
headers = {
"Authorization": f"Bearer {calling_key}",
"Content-Type": "application/json",
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json={}) as response:
if response.status == 200:
data = await response.json()
return data.get("key")
raise Exception(f"Failed to generate key: {response.status}")
@pytest.mark.asyncio
async def test_assemblyai_transcribe_with_non_admin_key():
# Generate a non-admin key using the helper
non_admin_key = await generate_key(TEST_MASTER_KEY)
print(f"Generated non-admin key: {non_admin_key}")
# Use the non-admin key to transcribe
# Replace with your API key
aai.settings.api_key = f"Bearer {non_admin_key}"
aai.settings.base_url = TEST_BASE_URL
# URL of the file to transcribe
FILE_URL = "https://assembly.ai/wildfires.mp3"
# You can also transcribe a local file by passing in a file path
# FILE_URL = './path/to/file.mp3'
request_start_time = time.time()
transcriber = aai.Transcriber()
transcript = transcriber.transcribe(FILE_URL)
print(transcript)
print(transcript.id)
if transcript.id:
transcript.delete_by_id(transcript.id)
else:
pytest.fail("Failed to get transcript id")
if transcript.status == aai.TranscriptStatus.error:
print(transcript.error)
pytest.fail(f"Failed to transcribe file error: {transcript.error}")
else:
print(transcript.text)
request_end_time = time.time()
print(f"Request took {request_end_time - request_start_time} seconds")

View file

@ -0,0 +1,222 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock, Mock, patch
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system-path
import httpx
import pytest
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock, Mock, patch
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system-path
import httpx
import pytest
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.assembly_passthrough_logging_handler import (
AssemblyAIPassthroughLoggingHandler,
AssemblyAITranscriptResponse,
)
from litellm.proxy.pass_through_endpoints.success_handler import (
PassThroughEndpointLogging,
)
@pytest.fixture
def assembly_handler():
handler = AssemblyAIPassthroughLoggingHandler()
handler.assemblyai_api_key = "test-key"
return handler
@pytest.fixture
def mock_transcript_response():
return {
"id": "test-transcript-id",
"language_model": "default",
"acoustic_model": "default",
"language_code": "en",
"status": "completed",
"audio_duration": 100.0,
}
def test_should_log_request():
handler = AssemblyAIPassthroughLoggingHandler()
assert handler._should_log_request("POST") == True
assert handler._should_log_request("GET") == False
def test_get_assembly_transcript(assembly_handler, mock_transcript_response):
"""
Test that the _get_assembly_transcript method calls GET /v2/transcript/{transcript_id}
"""
with patch("httpx.get") as mock_get:
mock_get.return_value.json.return_value = mock_transcript_response
mock_get.return_value.raise_for_status.return_value = None
transcript = assembly_handler._get_assembly_transcript("test-transcript-id")
assert transcript == mock_transcript_response
mock_get.assert_called_once_with(
"https://api.assemblyai.com/v2/transcript/test-transcript-id",
headers={
"Authorization": "Bearer test-key",
"Content-Type": "application/json",
},
)
def test_poll_assembly_for_transcript_response(
assembly_handler, mock_transcript_response
):
"""
Test that the _poll_assembly_for_transcript_response method returns the correct transcript response
"""
with patch("httpx.get") as mock_get:
mock_get.return_value.json.return_value = mock_transcript_response
mock_get.return_value.raise_for_status.return_value = None
# Override polling settings for faster test
assembly_handler.polling_interval = 0.01
assembly_handler.max_polling_attempts = 2
transcript = assembly_handler._poll_assembly_for_transcript_response(
"test-transcript-id"
)
assert transcript == AssemblyAITranscriptResponse(**mock_transcript_response)
@pytest.fixture
def mock_request():
request = Mock()
request.method = "POST"
request.headers = {}
request.url = httpx.URL("http://test.com/test")
return request
@pytest.fixture
def mock_response():
return Mock()
@pytest.fixture
def mock_user_api_key_dict():
return {"api_key": "test-key"}
@patch("litellm.utils.get_secret")
@patch(
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.create_pass_through_route"
)
@pytest.mark.asyncio()
async def test_assemblyai_proxy_route_basic_post(
mock_create_route,
mock_get_secret,
mock_request,
mock_response,
mock_user_api_key_dict,
):
"""Test basic POST request handling for AssemblyAI proxy route"""
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
assemblyai_proxy_route,
)
# Setup mocks
mock_get_secret.return_value = "test-assemblyai-key"
mock_request.json = AsyncMock(return_value={"text": "test"})
mock_endpoint_func = AsyncMock(return_value={"result": "success"})
mock_create_route.return_value = mock_endpoint_func
result = await assemblyai_proxy_route(
endpoint="v2/transcript",
request=mock_request,
fastapi_response=mock_response,
user_api_key_dict=mock_user_api_key_dict,
)
assert result == {"result": "success"}
mock_create_route.assert_called_once_with(
endpoint="v2/transcript",
target="https://api.assemblyai.com/v2/transcript",
custom_headers={"Authorization": "test-assemblyai-key"},
)
@patch("litellm.utils.get_secret")
@patch(
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.create_pass_through_route"
)
@pytest.mark.asyncio()
async def test_assemblyai_proxy_route_get_transcript(
mock_create_route,
mock_get_secret,
mock_request,
mock_response,
mock_user_api_key_dict,
):
"""Test GET request handling for retrieving a specific transcript from AssemblyAI"""
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
assemblyai_proxy_route,
)
# Setup mocks
mock_get_secret.return_value = "test-assemblyai-key"
mock_request.method = "GET"
mock_endpoint_func = AsyncMock(
return_value={"id": "test-transcript-id", "status": "completed"}
)
mock_create_route.return_value = mock_endpoint_func
result = await assemblyai_proxy_route(
endpoint="v2/transcript/test-transcript-id",
request=mock_request,
fastapi_response=mock_response,
user_api_key_dict=mock_user_api_key_dict,
)
assert result == {"id": "test-transcript-id", "status": "completed"}
mock_create_route.assert_called_once_with(
endpoint="v2/transcript/test-transcript-id",
target="https://api.assemblyai.com/v2/transcript/test-transcript-id",
custom_headers={"Authorization": "test-assemblyai-key"},
)
def test_is_assemblyai_route():
"""
Test that the is_assemblyai_route method correctly identifies AssemblyAI routes
"""
handler = PassThroughEndpointLogging()
# Test positive cases
assert (
handler.is_assemblyai_route("https://api.assemblyai.com/v2/transcript") == True
)
assert handler.is_assemblyai_route("https://api.assemblyai.com/other/path") == True
assert handler.is_assemblyai_route("https://api.assemblyai.com/transcript") == True
# Test negative cases
assert handler.is_assemblyai_route("https://example.com/other") == False
assert (
handler.is_assemblyai_route("https://api.openai.com/v1/chat/completions")
== False
)
assert handler.is_assemblyai_route("") == False