(Feat) - New pass through add assembly ai passthrough endpoints (#8220)

* add assembly ai pass through request

* fix assembly pass through

* fix test_assemblyai_basic_transcribe

* fix assemblyai auth check

* test_assemblyai_transcribe_with_non_admin_key

* working assembly ai test

* working assembly ai proxy route

* use helper func to pass through logging

* clean up logging assembly ai

* test: update test to handle gemini token counter change

* fix(factory.py): fix bedrock http:// handling

* add unit testing for assembly pt handler

* docs assembly ai pass through endpoint

* fix proxy_pass_through_endpoint_tests

* fix standard_passthrough_logging_object

* fix ASSEMBLYAI_API_KEY

* test test_assemblyai_proxy_route_basic_post

* test_assemblyai_proxy_route_get_transcript

* fix is is_assemblyai_route

* test_is_assemblyai_route

---------

Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
This commit is contained in:
Ishaan Jaff 2025-02-03 21:54:32 -08:00 committed by GitHub
parent 0ec5205711
commit adeb09e091
9 changed files with 731 additions and 16 deletions

View file

@ -1,6 +1,7 @@
import json
from datetime import datetime
from typing import Optional
from typing import Optional, Union
from urllib.parse import urlparse
import httpx
@ -12,6 +13,9 @@ from litellm.utils import executor as thread_pool_executor
from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
AnthropicPassthroughLoggingHandler,
)
from .llm_provider_handlers.assembly_passthrough_logging_handler import (
AssemblyAIPassthroughLoggingHandler,
)
from .llm_provider_handlers.vertex_passthrough_logging_handler import (
VertexPassthroughLoggingHandler,
)
@ -28,6 +32,48 @@ class PassThroughEndpointLogging:
# Anthropic
self.TRACKED_ANTHROPIC_ROUTES = ["/messages"]
self.assemblyai_passthrough_logging_handler = (
AssemblyAIPassthroughLoggingHandler()
)
async def _handle_logging(
self,
logging_obj: LiteLLMLoggingObj,
standard_logging_response_object: Union[
StandardPassThroughResponseObject,
PassThroughEndpointLoggingResultValues,
dict,
],
result: str,
start_time: datetime,
end_time: datetime,
cache_hit: bool,
**kwargs,
):
"""Helper function to handle both sync and async logging operations"""
# Submit to thread pool for sync logging
thread_pool_executor.submit(
logging_obj.success_handler,
standard_logging_response_object,
start_time,
end_time,
cache_hit,
**kwargs,
)
# Handle async logging
await logging_obj.async_success_handler(
result=(
json.dumps(result)
if isinstance(result, dict)
else standard_logging_response_object
),
start_time=start_time,
end_time=end_time,
cache_hit=False,
**kwargs,
)
async def pass_through_async_success_handler(
self,
httpx_response: httpx.Response,
@ -79,28 +125,39 @@ class PassThroughEndpointLogging:
anthropic_passthrough_logging_handler_result["result"]
)
kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
elif self.is_assemblyai_route(url_route):
if (
AssemblyAIPassthroughLoggingHandler._should_log_request(
httpx_response.request.method
)
is not True
):
return
self.assemblyai_passthrough_logging_handler.assemblyai_passthrough_logging_handler(
httpx_response=httpx_response,
response_body=response_body or {},
logging_obj=logging_obj,
url_route=url_route,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
return
if standard_logging_response_object is None:
standard_logging_response_object = StandardPassThroughResponseObject(
response=httpx_response.text
)
thread_pool_executor.submit(
logging_obj.success_handler,
standard_logging_response_object, # Positional argument 1
start_time, # Positional argument 2
end_time, # Positional argument 3
cache_hit, # Positional argument 4
**kwargs, # Unpacked keyword arguments
)
await logging_obj.async_success_handler(
result=(
json.dumps(result)
if isinstance(result, dict)
else standard_logging_response_object
),
await self._handle_logging(
logging_obj=logging_obj,
standard_logging_response_object=standard_logging_response_object,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=False,
cache_hit=cache_hit,
**kwargs,
)
@ -115,3 +172,11 @@ class PassThroughEndpointLogging:
if route in url_route:
return True
return False
def is_assemblyai_route(self, url_route: str):
parsed_url = urlparse(url_route)
if parsed_url.hostname == "api.assemblyai.com":
return True
elif "/transcript" in parsed_url.path:
return True
return False