mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
(QA+UI) - e2e flow for adding assembly ai passthrough endpoints (#8337)
* add initial test for assembly ai * start using PassthroughEndpointRouter * migrate to litellm passthrough endpoints * add assembly ai as a known provider * fix PassthroughEndpointRouter * fix set_pass_through_credentials * working EU request to assembly ai pass through endpoint * add e2e test assembly * test_assemblyai_routes_with_bad_api_key * clean up pass through endpoint router * e2e testing for assembly ai pass through * test assembly ai e2e testing * delete assembly ai models * fix code quality * ui working assembly ai api base flow * fix install assembly ai * update model call details with kwargs for pass through logging * fix tracking assembly ai model in response * _handle_assemblyai_passthrough_logging * fix test_initialize_deployment_for_pass_through_unsupported_provider * TestPassthroughEndpointRouter * _get_assembly_transcript * fix assembly ai pt logging tests * fix assemblyai_proxy_route * fix _get_assembly_region_from_url
This commit is contained in:
parent
5dcb87a88b
commit
65c91cbbbc
13 changed files with 656 additions and 79 deletions
|
@ -248,6 +248,7 @@ class LiteLLMRoutes(enum.Enum):
|
|||
"/azure",
|
||||
"/openai",
|
||||
"/assemblyai",
|
||||
"/eu.assemblyai",
|
||||
]
|
||||
|
||||
anthropic_routes = [
|
||||
|
|
|
@ -20,9 +20,13 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
|
|||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
from .passthrough_endpoint_router import PassthroughEndpointRouter
|
||||
|
||||
router = APIRouter()
|
||||
default_vertex_config = None
|
||||
|
||||
passthrough_endpoint_router = PassthroughEndpointRouter()
|
||||
|
||||
|
||||
def create_request_copy(request: Request):
|
||||
return {
|
||||
|
@ -68,8 +72,9 @@ async def gemini_proxy_route(
|
|||
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||
|
||||
# Add or update query parameters
|
||||
gemini_api_key: Optional[str] = litellm.utils.get_secret( # type: ignore
|
||||
secret_name="GEMINI_API_KEY"
|
||||
gemini_api_key: Optional[str] = passthrough_endpoint_router.get_credentials(
|
||||
custom_llm_provider="gemini",
|
||||
region_name=None,
|
||||
)
|
||||
if gemini_api_key is None:
|
||||
raise Exception(
|
||||
|
@ -126,7 +131,10 @@ async def cohere_proxy_route(
|
|||
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||
|
||||
# Add or update query parameters
|
||||
cohere_api_key = litellm.utils.get_secret(secret_name="COHERE_API_KEY")
|
||||
cohere_api_key = passthrough_endpoint_router.get_credentials(
|
||||
custom_llm_provider="cohere",
|
||||
region_name=None,
|
||||
)
|
||||
|
||||
## check for streaming
|
||||
is_streaming_request = False
|
||||
|
@ -175,7 +183,10 @@ async def anthropic_proxy_route(
|
|||
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||
|
||||
# Add or update query parameters
|
||||
anthropic_api_key = litellm.utils.get_secret(secret_name="ANTHROPIC_API_KEY")
|
||||
anthropic_api_key = passthrough_endpoint_router.get_credentials(
|
||||
custom_llm_provider="anthropic",
|
||||
region_name=None,
|
||||
)
|
||||
|
||||
## check for streaming
|
||||
is_streaming_request = False
|
||||
|
@ -297,18 +308,34 @@ def _is_bedrock_agent_runtime_route(endpoint: str) -> bool:
|
|||
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
||||
tags=["AssemblyAI Pass-through", "pass-through"],
|
||||
)
|
||||
@router.api_route(
|
||||
"/eu.assemblyai/{endpoint:path}",
|
||||
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
||||
tags=["AssemblyAI EU Pass-through", "pass-through"],
|
||||
)
|
||||
async def assemblyai_proxy_route(
|
||||
endpoint: str,
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.assembly_passthrough_logging_handler import (
|
||||
AssemblyAIPassthroughLoggingHandler,
|
||||
)
|
||||
|
||||
"""
|
||||
[Docs](https://api.assemblyai.com)
|
||||
"""
|
||||
base_target_url = "https://api.assemblyai.com"
|
||||
# Set base URL based on the route
|
||||
assembly_region = AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
|
||||
url=str(request.url)
|
||||
)
|
||||
base_target_url = (
|
||||
AssemblyAIPassthroughLoggingHandler._get_assembly_base_url_from_region(
|
||||
region=assembly_region
|
||||
)
|
||||
)
|
||||
encoded_endpoint = httpx.URL(endpoint).path
|
||||
|
||||
# Ensure endpoint starts with '/' for proper URL construction
|
||||
if not encoded_endpoint.startswith("/"):
|
||||
encoded_endpoint = "/" + encoded_endpoint
|
||||
|
@ -318,7 +345,10 @@ async def assemblyai_proxy_route(
|
|||
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||
|
||||
# Add or update query parameters
|
||||
assemblyai_api_key = litellm.utils.get_secret(secret_name="ASSEMBLYAI_API_KEY")
|
||||
assemblyai_api_key = passthrough_endpoint_router.get_credentials(
|
||||
custom_llm_provider="assemblyai",
|
||||
region_name=assembly_region,
|
||||
)
|
||||
|
||||
## check for streaming
|
||||
is_streaming_request = False
|
||||
|
@ -366,7 +396,10 @@ async def azure_proxy_route(
|
|||
"Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure."
|
||||
)
|
||||
# Add or update query parameters
|
||||
azure_api_key = get_secret_str(secret_name="AZURE_API_KEY")
|
||||
azure_api_key = passthrough_endpoint_router.get_credentials(
|
||||
custom_llm_provider="azure",
|
||||
region_name=None,
|
||||
)
|
||||
if azure_api_key is None:
|
||||
raise Exception(
|
||||
"Required 'AZURE_API_KEY' in environment to make pass-through calls to Azure."
|
||||
|
@ -400,7 +433,10 @@ async def openai_proxy_route(
|
|||
"""
|
||||
base_target_url = "https://api.openai.com"
|
||||
# Add or update query parameters
|
||||
openai_api_key = get_secret_str(secret_name="OPENAI_API_KEY")
|
||||
openai_api_key = passthrough_endpoint_router.get_credentials(
|
||||
custom_llm_provider="openai",
|
||||
region_name=None,
|
||||
)
|
||||
if openai_api_key is None:
|
||||
raise Exception(
|
||||
"Required 'OPENAI_API_KEY' in environment to make pass-through calls to OpenAI."
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional, TypedDict
|
||||
from typing import Literal, Optional, TypedDict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
|
@ -18,7 +19,7 @@ from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggin
|
|||
|
||||
class AssemblyAITranscriptResponse(TypedDict, total=False):
|
||||
id: str
|
||||
language_model: str
|
||||
speech_model: str
|
||||
acoustic_model: str
|
||||
language_code: str
|
||||
status: str
|
||||
|
@ -27,7 +28,8 @@ class AssemblyAITranscriptResponse(TypedDict, total=False):
|
|||
|
||||
class AssemblyAIPassthroughLoggingHandler:
|
||||
def __init__(self):
|
||||
self.assembly_ai_base_url = "https://api.assemblyai.com/v2"
|
||||
self.assembly_ai_base_url = "https://api.assemblyai.com"
|
||||
self.assembly_ai_eu_base_url = "https://eu.assemblyai.com"
|
||||
"""
|
||||
The base URL for the AssemblyAI API
|
||||
"""
|
||||
|
@ -43,8 +45,6 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
The maximum number of polling attempts for the AssemblyAI API.
|
||||
"""
|
||||
|
||||
self.assemblyai_api_key = os.environ.get("ASSEMBLYAI_API_KEY")
|
||||
|
||||
def assemblyai_passthrough_logging_handler(
|
||||
self,
|
||||
httpx_response: httpx.Response,
|
||||
|
@ -90,27 +90,34 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
"""
|
||||
from ..pass_through_endpoints import pass_through_endpoint_logging
|
||||
|
||||
model = response_body.get("model", "")
|
||||
verbose_proxy_logger.debug("response body", json.dumps(response_body, indent=4))
|
||||
model = response_body.get("speech_model", "")
|
||||
verbose_proxy_logger.debug(
|
||||
"response body %s", json.dumps(response_body, indent=4)
|
||||
)
|
||||
kwargs["model"] = model
|
||||
kwargs["custom_llm_provider"] = "assemblyai"
|
||||
logging_obj.model_call_details["model"] = model
|
||||
logging_obj.model_call_details["custom_llm_provider"] = "assemblyai"
|
||||
|
||||
transcript_id = response_body.get("id")
|
||||
if transcript_id is None:
|
||||
raise ValueError(
|
||||
"Transcript ID is required to log the cost of the transcription"
|
||||
)
|
||||
transcript_response = self._poll_assembly_for_transcript_response(transcript_id)
|
||||
transcript_response = self._poll_assembly_for_transcript_response(
|
||||
transcript_id=transcript_id, url_route=url_route
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"finished polling assembly for transcript response- got transcript response",
|
||||
"finished polling assembly for transcript response- got transcript response %s",
|
||||
json.dumps(transcript_response, indent=4),
|
||||
)
|
||||
if transcript_response:
|
||||
cost = self.get_cost_for_assembly_transcript(transcript_response)
|
||||
cost = self.get_cost_for_assembly_transcript(
|
||||
speech_model=model,
|
||||
transcript_response=transcript_response,
|
||||
)
|
||||
kwargs["response_cost"] = cost
|
||||
|
||||
logging_obj.model_call_details["model"] = logging_obj.model
|
||||
|
||||
# Make standard logging object for Vertex AI
|
||||
standard_logging_object = get_standard_logging_object_payload(
|
||||
kwargs=kwargs,
|
||||
|
@ -157,7 +164,11 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
return {}
|
||||
return dict(transcript_response)
|
||||
|
||||
def _get_assembly_transcript(self, transcript_id: str) -> Optional[dict]:
|
||||
def _get_assembly_transcript(
|
||||
self,
|
||||
transcript_id: str,
|
||||
request_region: Optional[Literal["eu"]] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the transcript details from AssemblyAI API
|
||||
|
||||
|
@ -167,10 +178,25 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
Returns:
|
||||
Optional[dict]: Transcript details if successful, None otherwise
|
||||
"""
|
||||
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
|
||||
passthrough_endpoint_router,
|
||||
)
|
||||
|
||||
_base_url = (
|
||||
self.assembly_ai_eu_base_url
|
||||
if request_region == "eu"
|
||||
else self.assembly_ai_base_url
|
||||
)
|
||||
_api_key = passthrough_endpoint_router.get_credentials(
|
||||
custom_llm_provider="assemblyai",
|
||||
region_name=request_region,
|
||||
)
|
||||
if _api_key is None:
|
||||
raise ValueError("AssemblyAI API key not found")
|
||||
try:
|
||||
url = f"{self.assembly_ai_base_url}/transcript/{transcript_id}"
|
||||
url = f"{_base_url}/v2/transcript/{transcript_id}"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.assemblyai_api_key}",
|
||||
"Authorization": f"Bearer {_api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
@ -179,11 +205,15 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"Error getting AssemblyAI transcript: {str(e)}")
|
||||
verbose_proxy_logger.exception(
|
||||
f"[Non blocking logging error] Error getting AssemblyAI transcript: {str(e)}"
|
||||
)
|
||||
return None
|
||||
|
||||
def _poll_assembly_for_transcript_response(
|
||||
self, transcript_id: str
|
||||
self,
|
||||
transcript_id: str,
|
||||
url_route: Optional[str] = None,
|
||||
) -> Optional[AssemblyAITranscriptResponse]:
|
||||
"""
|
||||
Poll the status of the transcript until it is completed or timeout (30 minutes)
|
||||
|
@ -191,7 +221,12 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
for _ in range(
|
||||
self.max_polling_attempts
|
||||
): # 180 attempts * 10s = 30 minutes max
|
||||
transcript = self._get_assembly_transcript(transcript_id)
|
||||
transcript = self._get_assembly_transcript(
|
||||
request_region=AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
|
||||
url=url_route
|
||||
),
|
||||
transcript_id=transcript_id,
|
||||
)
|
||||
if transcript is None:
|
||||
return None
|
||||
if (
|
||||
|
@ -205,6 +240,7 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
@staticmethod
|
||||
def get_cost_for_assembly_transcript(
|
||||
transcript_response: AssemblyAITranscriptResponse,
|
||||
speech_model: str,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
Get the cost for the assembly transcript
|
||||
|
@ -212,7 +248,50 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
_audio_duration = transcript_response.get("audio_duration")
|
||||
if _audio_duration is None:
|
||||
return None
|
||||
return _audio_duration * 0.0001
|
||||
_cost_per_second = (
|
||||
AssemblyAIPassthroughLoggingHandler.get_cost_per_second_for_assembly_model(
|
||||
speech_model=speech_model
|
||||
)
|
||||
)
|
||||
if _cost_per_second is None:
|
||||
return None
|
||||
return _audio_duration * _cost_per_second
|
||||
|
||||
@staticmethod
|
||||
def get_cost_per_second_for_assembly_model(speech_model: str) -> Optional[float]:
|
||||
"""
|
||||
Get the cost per second for the assembly model.
|
||||
Falls back to assemblyai/nano if the specific speech model info cannot be found.
|
||||
"""
|
||||
try:
|
||||
# First try with the provided speech model
|
||||
try:
|
||||
model_info = litellm.get_model_info(
|
||||
model=speech_model,
|
||||
custom_llm_provider="assemblyai",
|
||||
)
|
||||
if model_info and model_info.get("input_cost_per_second") is not None:
|
||||
return model_info.get("input_cost_per_second")
|
||||
except Exception:
|
||||
pass # Continue to fallback if model not found
|
||||
|
||||
# Fallback to assemblyai/nano if speech model info not found
|
||||
try:
|
||||
model_info = litellm.get_model_info(
|
||||
model="assemblyai/nano",
|
||||
custom_llm_provider="assemblyai",
|
||||
)
|
||||
if model_info and model_info.get("input_cost_per_second") is not None:
|
||||
return model_info.get("input_cost_per_second")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
f"[Non blocking logging error] Error getting AssemblyAI model info: {str(e)}"
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _should_log_request(request_method: str) -> bool:
|
||||
|
@ -220,3 +299,25 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
only POST transcription jobs are logged. litellm will POLL assembly to wait for the transcription to complete to log the complete response / cost
|
||||
"""
|
||||
return request_method == "POST"
|
||||
|
||||
@staticmethod
|
||||
def _get_assembly_region_from_url(url: Optional[str]) -> Optional[Literal["eu"]]:
|
||||
"""
|
||||
Get the region from the URL
|
||||
"""
|
||||
if url is None:
|
||||
return None
|
||||
if urlparse(url).hostname == "eu.assemblyai.com":
|
||||
return "eu"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_assembly_base_url_from_region(region: Optional[Literal["eu"]]) -> str:
|
||||
"""
|
||||
Get the base URL for the AssemblyAI API
|
||||
if region == "eu", return "https://api.eu.assemblyai.com"
|
||||
else return "https://api.assemblyai.com"
|
||||
"""
|
||||
if region == "eu":
|
||||
return "https://api.eu.assemblyai.com"
|
||||
return "https://api.assemblyai.com"
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
from typing import Dict, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
|
||||
class PassthroughEndpointRouter:
    """
    Use this class to Set/Get credentials for pass-through endpoints
    """

    def __init__(self):
        # Maps credential name (e.g. "ASSEMBLYAI_EU_API_KEY") -> api key.
        self.credentials: Dict[str, str] = {}

    def set_pass_through_credentials(
        self,
        custom_llm_provider: str,
        api_base: Optional[str],
        api_key: Optional[str],
    ):
        """
        Set credentials for a pass-through endpoint. Used when a user adds a pass-through LLM endpoint on the UI.

        Args:
            custom_llm_provider: The provider of the pass-through endpoint
            api_base: The base URL of the pass-through endpoint (used to infer the region)
            api_key: The API key for the pass-through endpoint

        Raises:
            ValueError: If api_key is None.
        """
        # Fail fast before doing any work if no key was supplied.
        if api_key is None:
            raise ValueError("api_key is required for setting pass-through credentials")
        credential_name = self._get_credential_name_for_provider(
            custom_llm_provider=custom_llm_provider,
            region_name=self._get_region_name_from_api_base(
                api_base=api_base, custom_llm_provider=custom_llm_provider
            ),
        )
        self.credentials[credential_name] = api_key

    def get_credentials(
        self,
        custom_llm_provider: str,
        region_name: Optional[str],
    ) -> Optional[str]:
        """
        Look up the API key for a provider (and optional region).

        Checks UI-set credentials first, then falls back to the provider's
        default environment variable (e.g. OPENAI_API_KEY).
        """
        credential_name = self._get_credential_name_for_provider(
            custom_llm_provider=custom_llm_provider,
            region_name=region_name,
        )
        verbose_logger.debug(
            f"Pass-through llm endpoints router, looking for credentials for {credential_name}"
        )
        if credential_name in self.credentials:
            verbose_logger.debug(f"Found credentials for {credential_name}")
            return self.credentials[credential_name]
        else:
            verbose_logger.debug(
                f"No credentials found for {credential_name}, looking for env variable"
            )
            _env_variable_name = (
                self._get_default_env_variable_name_passthrough_endpoint(
                    custom_llm_provider=custom_llm_provider,
                )
            )
            return get_secret_str(_env_variable_name)

    def _get_credential_name_for_provider(
        self,
        custom_llm_provider: str,
        region_name: Optional[str],
    ) -> str:
        # e.g. ("assemblyai", "eu") -> "ASSEMBLYAI_EU_API_KEY";
        #      ("openai", None)     -> "OPENAI_API_KEY"
        if region_name is None:
            return f"{custom_llm_provider.upper()}_API_KEY"
        return f"{custom_llm_provider.upper()}_{region_name.upper()}_API_KEY"

    def _get_region_name_from_api_base(
        self,
        custom_llm_provider: str,
        api_base: Optional[str],
    ) -> Optional[str]:
        """
        Get the region name from the API base.

        Each provider might have a different way of specifying the region in the API base - this is where you can use conditional logic to handle that.
        """
        if custom_llm_provider == "assemblyai" and api_base:
            # Match "eu" only as a full hostname label (e.g. "eu.assemblyai.com",
            # "api.eu.assemblyai.com"). A bare `"eu" in api_base` substring check
            # would false-positive on unrelated URLs (e.g. paths containing "queue").
            host = api_base.split("//")[-1].split("/")[0]
            if "eu" in host.split("."):
                return "eu"
        return None

    @staticmethod
    def _get_default_env_variable_name_passthrough_endpoint(
        custom_llm_provider: str,
    ) -> str:
        # Fallback env var name when no UI-set credential exists.
        return f"{custom_llm_provider.upper()}_API_KEY"
|
Loading…
Add table
Add a link
Reference in a new issue