mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
(Feat) - New pass through add assembly ai passthrough endpoints (#8220)
* add assembly ai pass through request * fix assembly pass through * fix test_assemblyai_basic_transcribe * fix assemblyai auth check * test_assemblyai_transcribe_with_non_admin_key * working assembly ai test * working assembly ai proxy route * use helper func to pass through logging * clean up logging assembly ai * test: update test to handle gemini token counter change * fix(factory.py): fix bedrock http:// handling * add unit testing for assembly pt handler * docs assembly ai pass through endpoint * fix proxy_pass_through_endpoint_tests * fix standard_passthrough_logging_object * fix ASSEMBLYAI_API_KEY * test test_assemblyai_proxy_route_basic_post * test_assemblyai_proxy_route_get_transcript * fix is is_assemblyai_route * test_is_assemblyai_route --------- Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
This commit is contained in:
parent
0ec5205711
commit
adeb09e091
9 changed files with 731 additions and 16 deletions
|
@ -1594,6 +1594,7 @@ jobs:
|
|||
pip install "google-cloud-aiplatform==1.43.0"
|
||||
pip install aiohttp
|
||||
pip install "openai==1.54.0 "
|
||||
pip install "assemblyai==0.37.0"
|
||||
python -m pip install --upgrade pip
|
||||
pip install "pydantic==2.7.1"
|
||||
pip install "pytest==7.3.1"
|
||||
|
@ -1626,6 +1627,7 @@ jobs:
|
|||
-e OPENAI_API_KEY=$OPENAI_API_KEY \
|
||||
-e GEMINI_API_KEY=$GEMINI_API_KEY \
|
||||
-e ANTHROPIC_API_KEY=$ANTHROPIC_API_KEY \
|
||||
-e ASSEMBLYAI_API_KEY=$ASSEMBLYAI_API_KEY \
|
||||
-e USE_DDTRACE=True \
|
||||
-e DD_API_KEY=$DD_API_KEY \
|
||||
-e DD_SITE=$DD_SITE \
|
||||
|
|
55
docs/my-website/docs/pass_through/assembly_ai.md
Normal file
55
docs/my-website/docs/pass_through/assembly_ai.md
Normal file
|
@ -0,0 +1,55 @@
|
|||
# Assembly AI
|
||||
|
||||
Pass-through endpoints for Assembly AI - call Assembly AI endpoints, in native format (no translation).
|
||||
|
||||
| Feature | Supported | Notes |
|
||||
|-------|-------|-------|
|
||||
| Cost Tracking | ✅ | works across all integrations |
|
||||
| Logging | ✅ | works across all integrations |
|
||||
|
||||
|
||||
Supports **ALL** Assembly AI Endpoints
|
||||
|
||||
[**See All Assembly AI Endpoints**](https://www.assemblyai.com/docs/api-reference)
|
||||
|
||||
## Quick Start
|
||||
|
||||
Let's call the Assembly AI [`/v2/transcripts` endpoint](https://www.assemblyai.com/docs/api-reference/transcripts)
|
||||
|
||||
1. Add Assembly AI API Key to your environment
|
||||
|
||||
```bash
|
||||
export ASSEMBLYAI_API_KEY=""
|
||||
```
|
||||
|
||||
2. Start LiteLLM Proxy
|
||||
|
||||
```bash
|
||||
litellm
|
||||
|
||||
# RUNNING on http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
Let's call the Assembly AI `/v2/transcripts` endpoint
|
||||
|
||||
```python
|
||||
LITELLM_VIRTUAL_KEY = "sk-1234" # <your-virtual-key>
|
||||
LITELLM_PROXY_BASE_URL = "http://0.0.0.0:4000/assemblyai" # <your-proxy-base-url>/assemblyai
|
||||
|
||||
aai.settings.api_key = f"Bearer {LITELLM_VIRTUAL_KEY}"
|
||||
aai.settings.base_url = LITELLM_PROXY_BASE_URL
|
||||
|
||||
# URL of the file to transcribe
|
||||
FILE_URL = "https://assembly.ai/wildfires.mp3"
|
||||
|
||||
# You can also transcribe a local file by passing in a file path
|
||||
# FILE_URL = './path/to/file.mp3'
|
||||
|
||||
transcriber = aai.Transcriber()
|
||||
transcript = transcriber.transcribe(FILE_URL)
|
||||
print(transcript)
|
||||
print(transcript.id)
|
||||
```
|
||||
|
|
@ -305,6 +305,7 @@ const sidebars = {
|
|||
"pass_through/cohere",
|
||||
"pass_through/anthropic_completion",
|
||||
"pass_through/bedrock",
|
||||
"pass_through/assembly_ai",
|
||||
"pass_through/langfuse",
|
||||
"proxy/pass_through",
|
||||
],
|
||||
|
|
|
@ -247,6 +247,7 @@ class LiteLLMRoutes(enum.Enum):
|
|||
"/langfuse",
|
||||
"/azure",
|
||||
"/openai",
|
||||
"/assemblyai",
|
||||
]
|
||||
|
||||
anthropic_routes = [
|
||||
|
|
|
@ -292,6 +292,58 @@ def _is_bedrock_agent_runtime_route(endpoint: str) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
@router.api_route(
|
||||
"/assemblyai/{endpoint:path}",
|
||||
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
||||
tags=["AssemblyAI Pass-through", "pass-through"],
|
||||
)
|
||||
async def assemblyai_proxy_route(
|
||||
endpoint: str,
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
[Docs](https://api.assemblyai.com)
|
||||
"""
|
||||
base_target_url = "https://api.assemblyai.com"
|
||||
encoded_endpoint = httpx.URL(endpoint).path
|
||||
|
||||
# Ensure endpoint starts with '/' for proper URL construction
|
||||
if not encoded_endpoint.startswith("/"):
|
||||
encoded_endpoint = "/" + encoded_endpoint
|
||||
|
||||
# Construct the full target URL using httpx
|
||||
base_url = httpx.URL(base_target_url)
|
||||
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||
|
||||
# Add or update query parameters
|
||||
assemblyai_api_key = litellm.utils.get_secret(secret_name="ASSEMBLYAI_API_KEY")
|
||||
|
||||
## check for streaming
|
||||
is_streaming_request = False
|
||||
# assemblyai is streaming when 'stream' = True is in the body
|
||||
if request.method == "POST":
|
||||
_request_body = await request.json()
|
||||
if _request_body.get("stream"):
|
||||
is_streaming_request = True
|
||||
|
||||
## CREATE PASS-THROUGH
|
||||
endpoint_func = create_pass_through_route(
|
||||
endpoint=endpoint,
|
||||
target=str(updated_url),
|
||||
custom_headers={"Authorization": "{}".format(assemblyai_api_key)},
|
||||
) # dynamically construct pass-through endpoint based on incoming path
|
||||
received_value = await endpoint_func(
|
||||
request=request,
|
||||
fastapi_response=fastapi_response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
stream=is_streaming_request, # type: ignore
|
||||
)
|
||||
|
||||
return received_value
|
||||
|
||||
|
||||
@router.api_route(
|
||||
"/azure/{endpoint:path}",
|
||||
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
||||
|
|
|
@ -0,0 +1,222 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional, TypedDict
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
get_standard_logging_object_payload,
|
||||
)
|
||||
from litellm.litellm_core_utils.thread_pool_executor import executor
|
||||
from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggingPayload
|
||||
|
||||
|
||||
class AssemblyAITranscriptResponse(TypedDict, total=False):
|
||||
id: str
|
||||
language_model: str
|
||||
acoustic_model: str
|
||||
language_code: str
|
||||
status: str
|
||||
audio_duration: float
|
||||
|
||||
|
||||
class AssemblyAIPassthroughLoggingHandler:
|
||||
def __init__(self):
|
||||
self.assembly_ai_base_url = "https://api.assemblyai.com/v2"
|
||||
"""
|
||||
The base URL for the AssemblyAI API
|
||||
"""
|
||||
|
||||
self.polling_interval: float = 10
|
||||
"""
|
||||
The polling interval for the AssemblyAI API.
|
||||
litellm needs to poll the GET /transcript/{transcript_id} endpoint to get the status of the transcript.
|
||||
"""
|
||||
|
||||
self.max_polling_attempts = 180
|
||||
"""
|
||||
The maximum number of polling attempts for the AssemblyAI API.
|
||||
"""
|
||||
|
||||
self.assemblyai_api_key = os.environ.get("ASSEMBLYAI_API_KEY")
|
||||
|
||||
def assemblyai_passthrough_logging_handler(
|
||||
self,
|
||||
httpx_response: httpx.Response,
|
||||
response_body: dict,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
url_route: str,
|
||||
result: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
cache_hit: bool,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Since cost tracking requires polling the AssemblyAI API, we need to handle this in a separate thread. Hence the executor.submit.
|
||||
"""
|
||||
executor.submit(
|
||||
self._handle_assemblyai_passthrough_logging,
|
||||
httpx_response,
|
||||
response_body,
|
||||
logging_obj,
|
||||
url_route,
|
||||
result,
|
||||
start_time,
|
||||
end_time,
|
||||
cache_hit,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _handle_assemblyai_passthrough_logging(
|
||||
self,
|
||||
httpx_response: httpx.Response,
|
||||
response_body: dict,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
url_route: str,
|
||||
result: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
cache_hit: bool,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Handles logging for AssemblyAI successful passthrough requests
|
||||
"""
|
||||
from ..pass_through_endpoints import pass_through_endpoint_logging
|
||||
|
||||
model = response_body.get("model", "")
|
||||
verbose_proxy_logger.debug("response body", json.dumps(response_body, indent=4))
|
||||
kwargs["model"] = model
|
||||
kwargs["custom_llm_provider"] = "assemblyai"
|
||||
|
||||
transcript_id = response_body.get("id")
|
||||
if transcript_id is None:
|
||||
raise ValueError(
|
||||
"Transcript ID is required to log the cost of the transcription"
|
||||
)
|
||||
transcript_response = self._poll_assembly_for_transcript_response(transcript_id)
|
||||
verbose_proxy_logger.debug(
|
||||
"finished polling assembly for transcript response- got transcript response",
|
||||
json.dumps(transcript_response, indent=4),
|
||||
)
|
||||
if transcript_response:
|
||||
cost = self.get_cost_for_assembly_transcript(transcript_response)
|
||||
kwargs["response_cost"] = cost
|
||||
|
||||
logging_obj.model_call_details["model"] = logging_obj.model
|
||||
|
||||
# Make standard logging object for Vertex AI
|
||||
standard_logging_object = get_standard_logging_object_payload(
|
||||
kwargs=kwargs,
|
||||
init_response_obj=transcript_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=logging_obj,
|
||||
status="success",
|
||||
)
|
||||
|
||||
passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = ( # type: ignore
|
||||
kwargs.get("passthrough_logging_payload")
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"standard_passthrough_logging_object %s",
|
||||
json.dumps(passthrough_logging_payload, indent=4),
|
||||
)
|
||||
|
||||
# pretty print standard logging object
|
||||
verbose_proxy_logger.debug(
|
||||
"standard_logging_object= %s", json.dumps(standard_logging_object, indent=4)
|
||||
)
|
||||
asyncio.run(
|
||||
pass_through_endpoint_logging._handle_logging(
|
||||
logging_obj=logging_obj,
|
||||
standard_logging_response_object=self._get_response_to_log(
|
||||
transcript_response
|
||||
),
|
||||
result=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=cache_hit,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
pass
|
||||
|
||||
def _get_response_to_log(
|
||||
self, transcript_response: Optional[AssemblyAITranscriptResponse]
|
||||
) -> dict:
|
||||
if transcript_response is None:
|
||||
return {}
|
||||
return dict(transcript_response)
|
||||
|
||||
def _get_assembly_transcript(self, transcript_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Get the transcript details from AssemblyAI API
|
||||
|
||||
Args:
|
||||
response_body (dict): Response containing the transcript ID
|
||||
|
||||
Returns:
|
||||
Optional[dict]: Transcript details if successful, None otherwise
|
||||
"""
|
||||
try:
|
||||
url = f"{self.assembly_ai_base_url}/transcript/{transcript_id}"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.assemblyai_api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = httpx.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"Error getting AssemblyAI transcript: {str(e)}")
|
||||
return None
|
||||
|
||||
def _poll_assembly_for_transcript_response(
|
||||
self, transcript_id: str
|
||||
) -> Optional[AssemblyAITranscriptResponse]:
|
||||
"""
|
||||
Poll the status of the transcript until it is completed or timeout (30 minutes)
|
||||
"""
|
||||
for _ in range(
|
||||
self.max_polling_attempts
|
||||
): # 180 attempts * 10s = 30 minutes max
|
||||
transcript = self._get_assembly_transcript(transcript_id)
|
||||
if transcript is None:
|
||||
return None
|
||||
if (
|
||||
transcript.get("status") == "completed"
|
||||
or transcript.get("status") == "error"
|
||||
):
|
||||
return AssemblyAITranscriptResponse(**transcript)
|
||||
time.sleep(self.polling_interval)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_cost_for_assembly_transcript(
|
||||
transcript_response: AssemblyAITranscriptResponse,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
Get the cost for the assembly transcript
|
||||
"""
|
||||
_audio_duration = transcript_response.get("audio_duration")
|
||||
if _audio_duration is None:
|
||||
return None
|
||||
return _audio_duration * 0.0001
|
||||
|
||||
@staticmethod
|
||||
def _should_log_request(request_method: str) -> bool:
|
||||
"""
|
||||
only POST transcription jobs are logged. litellm will POLL assembly to wait for the transcription to complete to log the complete response / cost
|
||||
"""
|
||||
return request_method == "POST"
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
|
@ -12,6 +13,9 @@ from litellm.utils import executor as thread_pool_executor
|
|||
from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
|
||||
AnthropicPassthroughLoggingHandler,
|
||||
)
|
||||
from .llm_provider_handlers.assembly_passthrough_logging_handler import (
|
||||
AssemblyAIPassthroughLoggingHandler,
|
||||
)
|
||||
from .llm_provider_handlers.vertex_passthrough_logging_handler import (
|
||||
VertexPassthroughLoggingHandler,
|
||||
)
|
||||
|
@ -28,6 +32,48 @@ class PassThroughEndpointLogging:
|
|||
# Anthropic
|
||||
self.TRACKED_ANTHROPIC_ROUTES = ["/messages"]
|
||||
|
||||
self.assemblyai_passthrough_logging_handler = (
|
||||
AssemblyAIPassthroughLoggingHandler()
|
||||
)
|
||||
|
||||
async def _handle_logging(
|
||||
self,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
standard_logging_response_object: Union[
|
||||
StandardPassThroughResponseObject,
|
||||
PassThroughEndpointLoggingResultValues,
|
||||
dict,
|
||||
],
|
||||
result: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
cache_hit: bool,
|
||||
**kwargs,
|
||||
):
|
||||
"""Helper function to handle both sync and async logging operations"""
|
||||
# Submit to thread pool for sync logging
|
||||
thread_pool_executor.submit(
|
||||
logging_obj.success_handler,
|
||||
standard_logging_response_object,
|
||||
start_time,
|
||||
end_time,
|
||||
cache_hit,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Handle async logging
|
||||
await logging_obj.async_success_handler(
|
||||
result=(
|
||||
json.dumps(result)
|
||||
if isinstance(result, dict)
|
||||
else standard_logging_response_object
|
||||
),
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def pass_through_async_success_handler(
|
||||
self,
|
||||
httpx_response: httpx.Response,
|
||||
|
@ -79,28 +125,39 @@ class PassThroughEndpointLogging:
|
|||
anthropic_passthrough_logging_handler_result["result"]
|
||||
)
|
||||
kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
|
||||
elif self.is_assemblyai_route(url_route):
|
||||
if (
|
||||
AssemblyAIPassthroughLoggingHandler._should_log_request(
|
||||
httpx_response.request.method
|
||||
)
|
||||
is not True
|
||||
):
|
||||
return
|
||||
self.assemblyai_passthrough_logging_handler.assemblyai_passthrough_logging_handler(
|
||||
httpx_response=httpx_response,
|
||||
response_body=response_body or {},
|
||||
logging_obj=logging_obj,
|
||||
url_route=url_route,
|
||||
result=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=cache_hit,
|
||||
**kwargs,
|
||||
)
|
||||
return
|
||||
|
||||
if standard_logging_response_object is None:
|
||||
standard_logging_response_object = StandardPassThroughResponseObject(
|
||||
response=httpx_response.text
|
||||
)
|
||||
thread_pool_executor.submit(
|
||||
logging_obj.success_handler,
|
||||
standard_logging_response_object, # Positional argument 1
|
||||
start_time, # Positional argument 2
|
||||
end_time, # Positional argument 3
|
||||
cache_hit, # Positional argument 4
|
||||
**kwargs, # Unpacked keyword arguments
|
||||
)
|
||||
|
||||
await logging_obj.async_success_handler(
|
||||
result=(
|
||||
json.dumps(result)
|
||||
if isinstance(result, dict)
|
||||
else standard_logging_response_object
|
||||
),
|
||||
await self._handle_logging(
|
||||
logging_obj=logging_obj,
|
||||
standard_logging_response_object=standard_logging_response_object,
|
||||
result=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=False,
|
||||
cache_hit=cache_hit,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -115,3 +172,11 @@ class PassThroughEndpointLogging:
|
|||
if route in url_route:
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_assemblyai_route(self, url_route: str):
|
||||
parsed_url = urlparse(url_route)
|
||||
if parsed_url.hostname == "api.assemblyai.com":
|
||||
return True
|
||||
elif "/transcript" in parsed_url.path:
|
||||
return True
|
||||
return False
|
||||
|
|
95
tests/pass_through_tests/test_assembly_ai.py
Normal file
95
tests/pass_through_tests/test_assembly_ai.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
"""
|
||||
This test ensures that the proxy can passthrough anthropic requests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import assemblyai as aai
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
TEST_MASTER_KEY = "sk-1234"
|
||||
TEST_BASE_URL = "http://0.0.0.0:4000/assemblyai"
|
||||
|
||||
|
||||
def test_assemblyai_basic_transcribe():
|
||||
print("making basic transcribe request to assemblyai passthrough")
|
||||
|
||||
# Replace with your API key
|
||||
aai.settings.api_key = f"Bearer {TEST_MASTER_KEY}"
|
||||
aai.settings.base_url = TEST_BASE_URL
|
||||
|
||||
# URL of the file to transcribe
|
||||
FILE_URL = "https://assembly.ai/wildfires.mp3"
|
||||
|
||||
# You can also transcribe a local file by passing in a file path
|
||||
# FILE_URL = './path/to/file.mp3'
|
||||
|
||||
transcriber = aai.Transcriber()
|
||||
transcript = transcriber.transcribe(FILE_URL)
|
||||
print(transcript)
|
||||
print(transcript.id)
|
||||
if transcript.id:
|
||||
transcript.delete_by_id(transcript.id)
|
||||
else:
|
||||
pytest.fail("Failed to get transcript id")
|
||||
|
||||
if transcript.status == aai.TranscriptStatus.error:
|
||||
print(transcript.error)
|
||||
pytest.fail(f"Failed to transcribe file error: {transcript.error}")
|
||||
else:
|
||||
print(transcript.text)
|
||||
|
||||
|
||||
async def generate_key(calling_key: str) -> str:
|
||||
"""Helper function to generate a new API key"""
|
||||
url = "http://0.0.0.0:4000/key/generate"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {calling_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, json={}) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
return data.get("key")
|
||||
raise Exception(f"Failed to generate key: {response.status}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assemblyai_transcribe_with_non_admin_key():
|
||||
# Generate a non-admin key using the helper
|
||||
non_admin_key = await generate_key(TEST_MASTER_KEY)
|
||||
print(f"Generated non-admin key: {non_admin_key}")
|
||||
|
||||
# Use the non-admin key to transcribe
|
||||
# Replace with your API key
|
||||
aai.settings.api_key = f"Bearer {non_admin_key}"
|
||||
aai.settings.base_url = TEST_BASE_URL
|
||||
|
||||
# URL of the file to transcribe
|
||||
FILE_URL = "https://assembly.ai/wildfires.mp3"
|
||||
|
||||
# You can also transcribe a local file by passing in a file path
|
||||
# FILE_URL = './path/to/file.mp3'
|
||||
|
||||
request_start_time = time.time()
|
||||
|
||||
transcriber = aai.Transcriber()
|
||||
transcript = transcriber.transcribe(FILE_URL)
|
||||
print(transcript)
|
||||
print(transcript.id)
|
||||
if transcript.id:
|
||||
transcript.delete_by_id(transcript.id)
|
||||
else:
|
||||
pytest.fail("Failed to get transcript id")
|
||||
|
||||
if transcript.status == aai.TranscriptStatus.error:
|
||||
print(transcript.error)
|
||||
pytest.fail(f"Failed to transcribe file error: {transcript.error}")
|
||||
else:
|
||||
print(transcript.text)
|
||||
|
||||
request_end_time = time.time()
|
||||
print(f"Request took {request_end_time - request_start_time} seconds")
|
|
@ -0,0 +1,222 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system-path
|
||||
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system-path
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.assembly_passthrough_logging_handler import (
|
||||
AssemblyAIPassthroughLoggingHandler,
|
||||
AssemblyAITranscriptResponse,
|
||||
)
|
||||
from litellm.proxy.pass_through_endpoints.success_handler import (
|
||||
PassThroughEndpointLogging,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def assembly_handler():
|
||||
handler = AssemblyAIPassthroughLoggingHandler()
|
||||
handler.assemblyai_api_key = "test-key"
|
||||
return handler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_transcript_response():
|
||||
return {
|
||||
"id": "test-transcript-id",
|
||||
"language_model": "default",
|
||||
"acoustic_model": "default",
|
||||
"language_code": "en",
|
||||
"status": "completed",
|
||||
"audio_duration": 100.0,
|
||||
}
|
||||
|
||||
|
||||
def test_should_log_request():
|
||||
handler = AssemblyAIPassthroughLoggingHandler()
|
||||
assert handler._should_log_request("POST") == True
|
||||
assert handler._should_log_request("GET") == False
|
||||
|
||||
|
||||
def test_get_assembly_transcript(assembly_handler, mock_transcript_response):
|
||||
"""
|
||||
Test that the _get_assembly_transcript method calls GET /v2/transcript/{transcript_id}
|
||||
"""
|
||||
with patch("httpx.get") as mock_get:
|
||||
mock_get.return_value.json.return_value = mock_transcript_response
|
||||
mock_get.return_value.raise_for_status.return_value = None
|
||||
|
||||
transcript = assembly_handler._get_assembly_transcript("test-transcript-id")
|
||||
assert transcript == mock_transcript_response
|
||||
|
||||
mock_get.assert_called_once_with(
|
||||
"https://api.assemblyai.com/v2/transcript/test-transcript-id",
|
||||
headers={
|
||||
"Authorization": "Bearer test-key",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_poll_assembly_for_transcript_response(
|
||||
assembly_handler, mock_transcript_response
|
||||
):
|
||||
"""
|
||||
Test that the _poll_assembly_for_transcript_response method returns the correct transcript response
|
||||
"""
|
||||
with patch("httpx.get") as mock_get:
|
||||
mock_get.return_value.json.return_value = mock_transcript_response
|
||||
mock_get.return_value.raise_for_status.return_value = None
|
||||
|
||||
# Override polling settings for faster test
|
||||
assembly_handler.polling_interval = 0.01
|
||||
assembly_handler.max_polling_attempts = 2
|
||||
|
||||
transcript = assembly_handler._poll_assembly_for_transcript_response(
|
||||
"test-transcript-id"
|
||||
)
|
||||
assert transcript == AssemblyAITranscriptResponse(**mock_transcript_response)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
request = Mock()
|
||||
request.method = "POST"
|
||||
request.headers = {}
|
||||
request.url = httpx.URL("http://test.com/test")
|
||||
return request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_api_key_dict():
|
||||
return {"api_key": "test-key"}
|
||||
|
||||
|
||||
@patch("litellm.utils.get_secret")
|
||||
@patch(
|
||||
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.create_pass_through_route"
|
||||
)
|
||||
@pytest.mark.asyncio()
|
||||
async def test_assemblyai_proxy_route_basic_post(
|
||||
mock_create_route,
|
||||
mock_get_secret,
|
||||
mock_request,
|
||||
mock_response,
|
||||
mock_user_api_key_dict,
|
||||
):
|
||||
"""Test basic POST request handling for AssemblyAI proxy route"""
|
||||
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
|
||||
assemblyai_proxy_route,
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_get_secret.return_value = "test-assemblyai-key"
|
||||
mock_request.json = AsyncMock(return_value={"text": "test"})
|
||||
mock_endpoint_func = AsyncMock(return_value={"result": "success"})
|
||||
mock_create_route.return_value = mock_endpoint_func
|
||||
|
||||
result = await assemblyai_proxy_route(
|
||||
endpoint="v2/transcript",
|
||||
request=mock_request,
|
||||
fastapi_response=mock_response,
|
||||
user_api_key_dict=mock_user_api_key_dict,
|
||||
)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
mock_create_route.assert_called_once_with(
|
||||
endpoint="v2/transcript",
|
||||
target="https://api.assemblyai.com/v2/transcript",
|
||||
custom_headers={"Authorization": "test-assemblyai-key"},
|
||||
)
|
||||
|
||||
|
||||
@patch("litellm.utils.get_secret")
|
||||
@patch(
|
||||
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.create_pass_through_route"
|
||||
)
|
||||
@pytest.mark.asyncio()
|
||||
async def test_assemblyai_proxy_route_get_transcript(
|
||||
mock_create_route,
|
||||
mock_get_secret,
|
||||
mock_request,
|
||||
mock_response,
|
||||
mock_user_api_key_dict,
|
||||
):
|
||||
"""Test GET request handling for retrieving a specific transcript from AssemblyAI"""
|
||||
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
|
||||
assemblyai_proxy_route,
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_get_secret.return_value = "test-assemblyai-key"
|
||||
mock_request.method = "GET"
|
||||
mock_endpoint_func = AsyncMock(
|
||||
return_value={"id": "test-transcript-id", "status": "completed"}
|
||||
)
|
||||
mock_create_route.return_value = mock_endpoint_func
|
||||
|
||||
result = await assemblyai_proxy_route(
|
||||
endpoint="v2/transcript/test-transcript-id",
|
||||
request=mock_request,
|
||||
fastapi_response=mock_response,
|
||||
user_api_key_dict=mock_user_api_key_dict,
|
||||
)
|
||||
|
||||
assert result == {"id": "test-transcript-id", "status": "completed"}
|
||||
mock_create_route.assert_called_once_with(
|
||||
endpoint="v2/transcript/test-transcript-id",
|
||||
target="https://api.assemblyai.com/v2/transcript/test-transcript-id",
|
||||
custom_headers={"Authorization": "test-assemblyai-key"},
|
||||
)
|
||||
|
||||
|
||||
def test_is_assemblyai_route():
|
||||
"""
|
||||
Test that the is_assemblyai_route method correctly identifies AssemblyAI routes
|
||||
"""
|
||||
handler = PassThroughEndpointLogging()
|
||||
|
||||
# Test positive cases
|
||||
assert (
|
||||
handler.is_assemblyai_route("https://api.assemblyai.com/v2/transcript") == True
|
||||
)
|
||||
assert handler.is_assemblyai_route("https://api.assemblyai.com/other/path") == True
|
||||
assert handler.is_assemblyai_route("https://api.assemblyai.com/transcript") == True
|
||||
|
||||
# Test negative cases
|
||||
assert handler.is_assemblyai_route("https://example.com/other") == False
|
||||
assert (
|
||||
handler.is_assemblyai_route("https://api.openai.com/v1/chat/completions")
|
||||
== False
|
||||
)
|
||||
assert handler.is_assemblyai_route("") == False
|
Loading…
Add table
Add a link
Reference in a new issue