From adeb09e091750ad43aa4dc9dbd7c16be93ad12dd Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 3 Feb 2025 21:54:32 -0800 Subject: [PATCH] (Feat) - New pass through add assembly ai passthrough endpoints (#8220) * add assembly ai pass through request * fix assembly pass through * fix test_assemblyai_basic_transcribe * fix assemblyai auth check * test_assemblyai_transcribe_with_non_admin_key * working assembly ai test * working assembly ai proxy route * use helper func to pass through logging * clean up logging assembly ai * test: update test to handle gemini token counter change * fix(factory.py): fix bedrock http:// handling * add unit testing for assembly pt handler * docs assembly ai pass through endpoint * fix proxy_pass_through_endpoint_tests * fix standard_passthrough_logging_object * fix ASSEMBLYAI_API_KEY * test test_assemblyai_proxy_route_basic_post * test_assemblyai_proxy_route_get_transcript * fix is is_assemblyai_route * test_is_assemblyai_route --------- Co-authored-by: Krrish Dholakia --- .circleci/config.yml | 2 + .../docs/pass_through/assembly_ai.md | 55 +++++ docs/my-website/sidebars.js | 1 + litellm/proxy/_types.py | 1 + .../llm_passthrough_endpoints.py | 52 ++++ .../assembly_passthrough_logging_handler.py | 222 ++++++++++++++++++ .../pass_through_endpoints/success_handler.py | 97 ++++++-- tests/pass_through_tests/test_assembly_ai.py | 95 ++++++++ .../test_assemblyai_unit_tests_passthrough.py | 222 ++++++++++++++++++ 9 files changed, 731 insertions(+), 16 deletions(-) create mode 100644 docs/my-website/docs/pass_through/assembly_ai.md create mode 100644 litellm/proxy/pass_through_endpoints/llm_provider_handlers/assembly_passthrough_logging_handler.py create mode 100644 tests/pass_through_tests/test_assembly_ai.py create mode 100644 tests/pass_through_unit_tests/test_assemblyai_unit_tests_passthrough.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 9acffc9625..1766736a9f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1594,6 +1594,7 @@ jobs: pip install "google-cloud-aiplatform==1.43.0" pip install aiohttp pip install "openai==1.54.0 " + pip install "assemblyai==0.37.0" python -m pip install --upgrade pip pip install "pydantic==2.7.1" pip install "pytest==7.3.1" @@ -1626,6 +1627,7 @@ jobs: -e OPENAI_API_KEY=$OPENAI_API_KEY \ -e GEMINI_API_KEY=$GEMINI_API_KEY \ -e ANTHROPIC_API_KEY=$ANTHROPIC_API_KEY \ + -e ASSEMBLYAI_API_KEY=$ASSEMBLYAI_API_KEY \ -e USE_DDTRACE=True \ -e DD_API_KEY=$DD_API_KEY \ -e DD_SITE=$DD_SITE \ diff --git a/docs/my-website/docs/pass_through/assembly_ai.md b/docs/my-website/docs/pass_through/assembly_ai.md new file mode 100644 index 0000000000..5f7f1c07e6 --- /dev/null +++ b/docs/my-website/docs/pass_through/assembly_ai.md @@ -0,0 +1,55 @@ +# Assembly AI + +Pass-through endpoints for Assembly AI - call Assembly AI endpoints, in native format (no translation). + +| Feature | Supported | Notes | +|-------|-------|-------| +| Cost Tracking | ✅ | works across all integrations | +| Logging | ✅ | works across all integrations | + + +Supports **ALL** Assembly AI Endpoints + +[**See All Assembly AI Endpoints**](https://www.assemblyai.com/docs/api-reference) + +## Quick Start + +Let's call the Assembly AI [`/v2/transcripts` endpoint](https://www.assemblyai.com/docs/api-reference/transcripts) + +1. Add Assembly AI API Key to your environment + +```bash +export ASSEMBLYAI_API_KEY="" +``` + +2. Start LiteLLM Proxy + +```bash +litellm + +# RUNNING on http://0.0.0.0:4000 +``` + +3. Test it! + +Let's call the Assembly AI `/v2/transcripts` endpoint + +```python +LITELLM_VIRTUAL_KEY = "sk-1234" # +LITELLM_PROXY_BASE_URL = "http://0.0.0.0:4000/assemblyai" # /assemblyai + +aai.settings.api_key = f"Bearer {LITELLM_VIRTUAL_KEY}" +aai.settings.base_url = LITELLM_PROXY_BASE_URL + +# URL of the file to transcribe +FILE_URL = "https://assembly.ai/wildfires.mp3" + +# You can also transcribe a local file by passing in a file path +# FILE_URL = './path/to/file.mp3' + +transcriber = aai.Transcriber() +transcript = transcriber.transcribe(FILE_URL) +print(transcript) +print(transcript.id) +``` + diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 5324d50ed4..9ed5f246bf 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -305,6 +305,7 @@ const sidebars = { "pass_through/cohere", "pass_through/anthropic_completion", "pass_through/bedrock", + "pass_through/assembly_ai", "pass_through/langfuse", "proxy/pass_through", ], diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 2117ee9a75..c449a21b02 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -247,6 +247,7 @@ class LiteLLMRoutes(enum.Enum): "/langfuse", "/azure", "/openai", + "/assemblyai", ] anthropic_routes = [ diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py index d22f7121e6..dce20b9775 100644 --- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py @@ -292,6 +292,58 @@ def _is_bedrock_agent_runtime_route(endpoint: str) -> bool: return False +@router.api_route( + "/assemblyai/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + tags=["AssemblyAI Pass-through", "pass-through"], +) +async def assemblyai_proxy_route( + endpoint: str, + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + [Docs](https://api.assemblyai.com) + """ + base_target_url = "https://api.assemblyai.com" + encoded_endpoint = httpx.URL(endpoint).path + + # Ensure endpoint starts with '/' for proper URL construction + if not encoded_endpoint.startswith("/"): + encoded_endpoint = "/" + encoded_endpoint + + # Construct the full target URL using httpx + base_url = httpx.URL(base_target_url) + updated_url = base_url.copy_with(path=encoded_endpoint) + + # Add or update query parameters + assemblyai_api_key = litellm.utils.get_secret(secret_name="ASSEMBLYAI_API_KEY") + + ## check for streaming + is_streaming_request = False + # assemblyai is streaming when 'stream' = True is in the body + if request.method == "POST": + _request_body = await request.json() + if _request_body.get("stream"): + is_streaming_request = True + + ## CREATE PASS-THROUGH + endpoint_func = create_pass_through_route( + endpoint=endpoint, + target=str(updated_url), + custom_headers={"Authorization": "{}".format(assemblyai_api_key)}, + ) # dynamically construct pass-through endpoint based on incoming path + received_value = await endpoint_func( + request=request, + fastapi_response=fastapi_response, + user_api_key_dict=user_api_key_dict, + stream=is_streaming_request, # type: ignore + ) + + return received_value + + @router.api_route( "/azure/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"], diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/assembly_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/assembly_passthrough_logging_handler.py new file mode 100644 index 0000000000..bf0f0f73a5 --- /dev/null +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/assembly_passthrough_logging_handler.py @@ -0,0 +1,222 @@ +import asyncio +import json +import os +import time +from datetime import datetime +from typing import Optional, TypedDict + +import httpx + +from litellm._logging import verbose_proxy_logger +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.litellm_core_utils.litellm_logging import ( + get_standard_logging_object_payload, +) +from litellm.litellm_core_utils.thread_pool_executor import executor +from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggingPayload + + +class AssemblyAITranscriptResponse(TypedDict, total=False): + id: str + language_model: str + acoustic_model: str + language_code: str + status: str + audio_duration: float + + +class AssemblyAIPassthroughLoggingHandler: + def __init__(self): + self.assembly_ai_base_url = "https://api.assemblyai.com/v2" + """ + The base URL for the AssemblyAI API + """ + + self.polling_interval: float = 10 + """ + The polling interval for the AssemblyAI API. + litellm needs to poll the GET /transcript/{transcript_id} endpoint to get the status of the transcript. + """ + + self.max_polling_attempts = 180 + """ + The maximum number of polling attempts for the AssemblyAI API. + """ + + self.assemblyai_api_key = os.environ.get("ASSEMBLYAI_API_KEY") + + def assemblyai_passthrough_logging_handler( + self, + httpx_response: httpx.Response, + response_body: dict, + logging_obj: LiteLLMLoggingObj, + url_route: str, + result: str, + start_time: datetime, + end_time: datetime, + cache_hit: bool, + **kwargs, + ): + """ + Since cost tracking requires polling the AssemblyAI API, we need to handle this in a separate thread. Hence the executor.submit. + """ + executor.submit( + self._handle_assemblyai_passthrough_logging, + httpx_response, + response_body, + logging_obj, + url_route, + result, + start_time, + end_time, + cache_hit, + **kwargs, + ) + + def _handle_assemblyai_passthrough_logging( + self, + httpx_response: httpx.Response, + response_body: dict, + logging_obj: LiteLLMLoggingObj, + url_route: str, + result: str, + start_time: datetime, + end_time: datetime, + cache_hit: bool, + **kwargs, + ): + """ + Handles logging for AssemblyAI successful passthrough requests + """ + from ..pass_through_endpoints import pass_through_endpoint_logging + + model = response_body.get("model", "") + verbose_proxy_logger.debug("response body", json.dumps(response_body, indent=4)) + kwargs["model"] = model + kwargs["custom_llm_provider"] = "assemblyai" + + transcript_id = response_body.get("id") + if transcript_id is None: + raise ValueError( + "Transcript ID is required to log the cost of the transcription" + ) + transcript_response = self._poll_assembly_for_transcript_response(transcript_id) + verbose_proxy_logger.debug( + "finished polling assembly for transcript response- got transcript response", + json.dumps(transcript_response, indent=4), + ) + if transcript_response: + cost = self.get_cost_for_assembly_transcript(transcript_response) + kwargs["response_cost"] = cost + + logging_obj.model_call_details["model"] = logging_obj.model + + # Make standard logging object for Vertex AI + standard_logging_object = get_standard_logging_object_payload( + kwargs=kwargs, + init_response_obj=transcript_response, + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + status="success", + ) + + passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = ( # type: ignore + kwargs.get("passthrough_logging_payload") + ) + + verbose_proxy_logger.debug( + "standard_passthrough_logging_object %s", + json.dumps(passthrough_logging_payload, indent=4), + ) + + # pretty print standard logging object + verbose_proxy_logger.debug( + "standard_logging_object= %s", json.dumps(standard_logging_object, indent=4) + ) + asyncio.run( + pass_through_endpoint_logging._handle_logging( + logging_obj=logging_obj, + standard_logging_response_object=self._get_response_to_log( + transcript_response + ), + result=result, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + **kwargs, + ) + ) + + pass + + def _get_response_to_log( + self, transcript_response: Optional[AssemblyAITranscriptResponse] + ) -> dict: + if transcript_response is None: + return {} + return dict(transcript_response) + + def _get_assembly_transcript(self, transcript_id: str) -> Optional[dict]: + """ + Get the transcript details from AssemblyAI API + + Args: + response_body (dict): Response containing the transcript ID + + Returns: + Optional[dict]: Transcript details if successful, None otherwise + """ + try: + url = f"{self.assembly_ai_base_url}/transcript/{transcript_id}" + headers = { + "Authorization": f"Bearer {self.assemblyai_api_key}", + "Content-Type": "application/json", + } + + response = httpx.get(url, headers=headers) + response.raise_for_status() + + return response.json() + except Exception as e: + verbose_proxy_logger.debug(f"Error getting AssemblyAI transcript: {str(e)}") + return None + + def _poll_assembly_for_transcript_response( + self, transcript_id: str + ) -> Optional[AssemblyAITranscriptResponse]: + """ + Poll the status of the transcript until it is completed or timeout (30 minutes) + """ + for _ in range( + self.max_polling_attempts + ): # 180 attempts * 10s = 30 minutes max + transcript = self._get_assembly_transcript(transcript_id) + if transcript is None: + return None + if ( + transcript.get("status") == "completed" + or transcript.get("status") == "error" + ): + return AssemblyAITranscriptResponse(**transcript) + time.sleep(self.polling_interval) + return None + + @staticmethod + def get_cost_for_assembly_transcript( + transcript_response: AssemblyAITranscriptResponse, + ) -> Optional[float]: + """ + Get the cost for the assembly transcript + """ + _audio_duration = transcript_response.get("audio_duration") + if _audio_duration is None: + return None + return _audio_duration * 0.0001 + + @staticmethod + def _should_log_request(request_method: str) -> bool: + """ + only POST transcription jobs are logged. litellm will POLL assembly to wait for the transcription to complete to log the complete response / cost + """ + return request_method == "POST" diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index 6f112aed1f..02e81566e8 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -1,6 +1,7 @@ import json from datetime import datetime -from typing import Optional +from typing import Optional, Union +from urllib.parse import urlparse import httpx @@ -12,6 +13,9 @@ from litellm.utils import executor as thread_pool_executor from .llm_provider_handlers.anthropic_passthrough_logging_handler import ( AnthropicPassthroughLoggingHandler, ) +from .llm_provider_handlers.assembly_passthrough_logging_handler import ( + AssemblyAIPassthroughLoggingHandler, +) from .llm_provider_handlers.vertex_passthrough_logging_handler import ( VertexPassthroughLoggingHandler, ) @@ -28,6 +32,48 @@ class PassThroughEndpointLogging: # Anthropic self.TRACKED_ANTHROPIC_ROUTES = ["/messages"] + self.assemblyai_passthrough_logging_handler = ( + AssemblyAIPassthroughLoggingHandler() + ) + + async def _handle_logging( + self, + logging_obj: LiteLLMLoggingObj, + standard_logging_response_object: Union[ + StandardPassThroughResponseObject, + PassThroughEndpointLoggingResultValues, + dict, + ], + result: str, + start_time: datetime, + end_time: datetime, + cache_hit: bool, + **kwargs, + ): + """Helper function to handle both sync and async logging operations""" + # Submit to thread pool for sync logging + thread_pool_executor.submit( + logging_obj.success_handler, + standard_logging_response_object, + start_time, + end_time, + cache_hit, + **kwargs, + ) + + # Handle async logging + await logging_obj.async_success_handler( + result=( + json.dumps(result) + if isinstance(result, dict) + else standard_logging_response_object + ), + start_time=start_time, + end_time=end_time, + cache_hit=False, + **kwargs, + ) + async def pass_through_async_success_handler( self, httpx_response: httpx.Response, @@ -79,28 +125,39 @@ class PassThroughEndpointLogging: anthropic_passthrough_logging_handler_result["result"] ) kwargs = anthropic_passthrough_logging_handler_result["kwargs"] + elif self.is_assemblyai_route(url_route): + if ( + AssemblyAIPassthroughLoggingHandler._should_log_request( + httpx_response.request.method + ) + is not True + ): + return + self.assemblyai_passthrough_logging_handler.assemblyai_passthrough_logging_handler( + httpx_response=httpx_response, + response_body=response_body or {}, + logging_obj=logging_obj, + url_route=url_route, + result=result, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + **kwargs, + ) + return + if standard_logging_response_object is None: standard_logging_response_object = StandardPassThroughResponseObject( response=httpx_response.text ) - thread_pool_executor.submit( - logging_obj.success_handler, - standard_logging_response_object, # Positional argument 1 - start_time, # Positional argument 2 - end_time, # Positional argument 3 - cache_hit, # Positional argument 4 - **kwargs, # Unpacked keyword arguments - ) - await logging_obj.async_success_handler( - result=( - json.dumps(result) - if isinstance(result, dict) - else standard_logging_response_object - ), + await self._handle_logging( + logging_obj=logging_obj, + standard_logging_response_object=standard_logging_response_object, + result=result, start_time=start_time, end_time=end_time, - cache_hit=False, + cache_hit=cache_hit, **kwargs, ) @@ -115,3 +172,11 @@ class PassThroughEndpointLogging: if route in url_route: return True return False + + def is_assemblyai_route(self, url_route: str): + parsed_url = urlparse(url_route) + if parsed_url.hostname == "api.assemblyai.com": + return True + elif "/transcript" in parsed_url.path: + return True + return False diff --git a/tests/pass_through_tests/test_assembly_ai.py b/tests/pass_through_tests/test_assembly_ai.py new file mode 100644 index 0000000000..dd9f91a0b4 --- /dev/null +++ b/tests/pass_through_tests/test_assembly_ai.py @@ -0,0 +1,95 @@ +""" +This test ensures that the proxy can passthrough anthropic requests +""" + +import pytest +import assemblyai as aai +import aiohttp +import asyncio +import time + +TEST_MASTER_KEY = "sk-1234" +TEST_BASE_URL = "http://0.0.0.0:4000/assemblyai" + + +def test_assemblyai_basic_transcribe(): + print("making basic transcribe request to assemblyai passthrough") + + # Replace with your API key + aai.settings.api_key = f"Bearer {TEST_MASTER_KEY}" + aai.settings.base_url = TEST_BASE_URL + + # URL of the file to transcribe + FILE_URL = "https://assembly.ai/wildfires.mp3" + + # You can also transcribe a local file by passing in a file path + # FILE_URL = './path/to/file.mp3' + + transcriber = aai.Transcriber() + transcript = transcriber.transcribe(FILE_URL) + print(transcript) + print(transcript.id) + if transcript.id: + transcript.delete_by_id(transcript.id) + else: + pytest.fail("Failed to get transcript id") + + if transcript.status == aai.TranscriptStatus.error: + print(transcript.error) + pytest.fail(f"Failed to transcribe file error: {transcript.error}") + else: + print(transcript.text) + + +async def generate_key(calling_key: str) -> str: + """Helper function to generate a new API key""" + url = "http://0.0.0.0:4000/key/generate" + headers = { + "Authorization": f"Bearer {calling_key}", + "Content-Type": "application/json", + } + + async with aiohttp.ClientSession() as session: + async with session.post(url, headers=headers, json={}) as response: + if response.status == 200: + data = await response.json() + return data.get("key") + raise Exception(f"Failed to generate key: {response.status}") + + +@pytest.mark.asyncio +async def test_assemblyai_transcribe_with_non_admin_key(): + # Generate a non-admin key using the helper + non_admin_key = await generate_key(TEST_MASTER_KEY) + print(f"Generated non-admin key: {non_admin_key}") + + # Use the non-admin key to transcribe + # Replace with your API key + aai.settings.api_key = f"Bearer {non_admin_key}" + aai.settings.base_url = TEST_BASE_URL + + # URL of the file to transcribe + FILE_URL = "https://assembly.ai/wildfires.mp3" + + # You can also transcribe a local file by passing in a file path + # FILE_URL = './path/to/file.mp3' + + request_start_time = time.time() + + transcriber = aai.Transcriber() + transcript = transcriber.transcribe(FILE_URL) + print(transcript) + print(transcript.id) + if transcript.id: + transcript.delete_by_id(transcript.id) + else: + pytest.fail("Failed to get transcript id") + + if transcript.status == aai.TranscriptStatus.error: + print(transcript.error) + pytest.fail(f"Failed to transcribe file error: {transcript.error}") + else: + print(transcript.text) + + request_end_time = time.time() + print(f"Request took {request_end_time - request_start_time} seconds") diff --git a/tests/pass_through_unit_tests/test_assemblyai_unit_tests_passthrough.py b/tests/pass_through_unit_tests/test_assemblyai_unit_tests_passthrough.py new file mode 100644 index 0000000000..71bae49671 --- /dev/null +++ b/tests/pass_through_unit_tests/test_assemblyai_unit_tests_passthrough.py @@ -0,0 +1,222 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock, Mock, patch + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system-path + + +import httpx +import pytest +import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + + +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock, Mock, patch + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system-path + +import httpx +import pytest +import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.proxy.pass_through_endpoints.llm_provider_handlers.assembly_passthrough_logging_handler import ( + AssemblyAIPassthroughLoggingHandler, + AssemblyAITranscriptResponse, +) +from litellm.proxy.pass_through_endpoints.success_handler import ( + PassThroughEndpointLogging, +) + + +@pytest.fixture +def assembly_handler(): + handler = AssemblyAIPassthroughLoggingHandler() + handler.assemblyai_api_key = "test-key" + return handler + + +@pytest.fixture +def mock_transcript_response(): + return { + "id": "test-transcript-id", + "language_model": "default", + "acoustic_model": "default", + "language_code": "en", + "status": "completed", + "audio_duration": 100.0, + } + + +def test_should_log_request(): + handler = AssemblyAIPassthroughLoggingHandler() + assert handler._should_log_request("POST") == True + assert handler._should_log_request("GET") == False + + +def test_get_assembly_transcript(assembly_handler, mock_transcript_response): + """ + Test that the _get_assembly_transcript method calls GET /v2/transcript/{transcript_id} + """ + with patch("httpx.get") as mock_get: + mock_get.return_value.json.return_value = mock_transcript_response + mock_get.return_value.raise_for_status.return_value = None + + transcript = assembly_handler._get_assembly_transcript("test-transcript-id") + assert transcript == mock_transcript_response + + mock_get.assert_called_once_with( + "https://api.assemblyai.com/v2/transcript/test-transcript-id", + headers={ + "Authorization": "Bearer test-key", + "Content-Type": "application/json", + }, + ) + + +def test_poll_assembly_for_transcript_response( + assembly_handler, mock_transcript_response +): + """ + Test that the _poll_assembly_for_transcript_response method returns the correct transcript response + """ + with patch("httpx.get") as mock_get: + mock_get.return_value.json.return_value = mock_transcript_response + mock_get.return_value.raise_for_status.return_value = None + + # Override polling settings for faster test + assembly_handler.polling_interval = 0.01 + assembly_handler.max_polling_attempts = 2 + + transcript = assembly_handler._poll_assembly_for_transcript_response( + "test-transcript-id" + ) + assert transcript == AssemblyAITranscriptResponse(**mock_transcript_response) + + +@pytest.fixture +def mock_request(): + request = Mock() + request.method = "POST" + request.headers = {} + request.url = httpx.URL("http://test.com/test") + return request + + +@pytest.fixture +def mock_response(): + return Mock() + + +@pytest.fixture +def mock_user_api_key_dict(): + return {"api_key": "test-key"} + + +@patch("litellm.utils.get_secret") +@patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.create_pass_through_route" +) +@pytest.mark.asyncio() +async def test_assemblyai_proxy_route_basic_post( + mock_create_route, + mock_get_secret, + mock_request, + mock_response, + mock_user_api_key_dict, +): + """Test basic POST request handling for AssemblyAI proxy route""" + from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( + assemblyai_proxy_route, + ) + + # Setup mocks + mock_get_secret.return_value = "test-assemblyai-key" + mock_request.json = AsyncMock(return_value={"text": "test"}) + mock_endpoint_func = AsyncMock(return_value={"result": "success"}) + mock_create_route.return_value = mock_endpoint_func + + result = await assemblyai_proxy_route( + endpoint="v2/transcript", + request=mock_request, + fastapi_response=mock_response, + user_api_key_dict=mock_user_api_key_dict, + ) + + assert result == {"result": "success"} + mock_create_route.assert_called_once_with( + endpoint="v2/transcript", + target="https://api.assemblyai.com/v2/transcript", + custom_headers={"Authorization": "test-assemblyai-key"}, + ) + + +@patch("litellm.utils.get_secret") +@patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.create_pass_through_route" +) +@pytest.mark.asyncio() +async def test_assemblyai_proxy_route_get_transcript( + mock_create_route, + mock_get_secret, + mock_request, + mock_response, + mock_user_api_key_dict, +): + """Test GET request handling for retrieving a specific transcript from AssemblyAI""" + from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( + assemblyai_proxy_route, + ) + + # Setup mocks + mock_get_secret.return_value = "test-assemblyai-key" + mock_request.method = "GET" + mock_endpoint_func = AsyncMock( + return_value={"id": "test-transcript-id", "status": "completed"} + ) + mock_create_route.return_value = mock_endpoint_func + + result = await assemblyai_proxy_route( + endpoint="v2/transcript/test-transcript-id", + request=mock_request, + fastapi_response=mock_response, + user_api_key_dict=mock_user_api_key_dict, + ) + + assert result == {"id": "test-transcript-id", "status": "completed"} + mock_create_route.assert_called_once_with( + endpoint="v2/transcript/test-transcript-id", + target="https://api.assemblyai.com/v2/transcript/test-transcript-id", + custom_headers={"Authorization": "test-assemblyai-key"}, + ) + + +def test_is_assemblyai_route(): + """ + Test that the is_assemblyai_route method correctly identifies AssemblyAI routes + """ + handler = PassThroughEndpointLogging() + + # Test positive cases + assert ( + handler.is_assemblyai_route("https://api.assemblyai.com/v2/transcript") == True + ) + assert handler.is_assemblyai_route("https://api.assemblyai.com/other/path") == True + assert handler.is_assemblyai_route("https://api.assemblyai.com/transcript") == True + + # Test negative cases + assert handler.is_assemblyai_route("https://example.com/other") == False + assert ( + handler.is_assemblyai_route("https://api.openai.com/v1/chat/completions") + == False + ) + assert handler.is_assemblyai_route("") == False