mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
(QA+UI) - e2e flow for adding assembly ai passthrough endpoints (#8337)
* add initial test for assembly ai * start using PassthroughEndpointRouter * migrate to lllm passthrough endpoints * add assembly ai as a known provider * fix PassthroughEndpointRouter * fix set_pass_through_credentials * working EU request to assembly ai pass through endpoint * add e2e test assembly * test_assemblyai_routes_with_bad_api_key * clean up pass through endpoint router * e2e testing for assembly ai pass through * test assembly ai e2e testing * delete assembly ai models * fix code quality * ui working assembly ai api base flow * fix install assembly ai * update model call details with kwargs for pass through logging * fix tracking assembly ai model in response * _handle_assemblyai_passthrough_logging * fix test_initialize_deployment_for_pass_through_unsupported_provider * TestPassthroughEndpointRouter * _get_assembly_transcript * fix assembly ai pt logging tests * fix assemblyai_proxy_route * fix _get_assembly_region_from_url
This commit is contained in:
parent
5dcb87a88b
commit
65c91cbbbc
13 changed files with 656 additions and 79 deletions
|
@ -1552,6 +1552,7 @@ jobs:
|
||||||
pip install "pytest-retry==1.6.3"
|
pip install "pytest-retry==1.6.3"
|
||||||
pip install "pytest-mock==3.12.0"
|
pip install "pytest-mock==3.12.0"
|
||||||
pip install "pytest-asyncio==0.21.1"
|
pip install "pytest-asyncio==0.21.1"
|
||||||
|
pip install "assemblyai==0.37.0"
|
||||||
- run:
|
- run:
|
||||||
name: Build Docker image
|
name: Build Docker image
|
||||||
command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
|
command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
|
||||||
|
|
|
@ -248,6 +248,7 @@ class LiteLLMRoutes(enum.Enum):
|
||||||
"/azure",
|
"/azure",
|
||||||
"/openai",
|
"/openai",
|
||||||
"/assemblyai",
|
"/assemblyai",
|
||||||
|
"/eu.assemblyai",
|
||||||
]
|
]
|
||||||
|
|
||||||
anthropic_routes = [
|
anthropic_routes = [
|
||||||
|
|
|
@ -20,9 +20,13 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
|
||||||
)
|
)
|
||||||
from litellm.secret_managers.main import get_secret_str
|
from litellm.secret_managers.main import get_secret_str
|
||||||
|
|
||||||
|
from .passthrough_endpoint_router import PassthroughEndpointRouter
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
default_vertex_config = None
|
default_vertex_config = None
|
||||||
|
|
||||||
|
passthrough_endpoint_router = PassthroughEndpointRouter()
|
||||||
|
|
||||||
|
|
||||||
def create_request_copy(request: Request):
|
def create_request_copy(request: Request):
|
||||||
return {
|
return {
|
||||||
|
@ -68,8 +72,9 @@ async def gemini_proxy_route(
|
||||||
updated_url = base_url.copy_with(path=encoded_endpoint)
|
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||||
|
|
||||||
# Add or update query parameters
|
# Add or update query parameters
|
||||||
gemini_api_key: Optional[str] = litellm.utils.get_secret( # type: ignore
|
gemini_api_key: Optional[str] = passthrough_endpoint_router.get_credentials(
|
||||||
secret_name="GEMINI_API_KEY"
|
custom_llm_provider="gemini",
|
||||||
|
region_name=None,
|
||||||
)
|
)
|
||||||
if gemini_api_key is None:
|
if gemini_api_key is None:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -126,7 +131,10 @@ async def cohere_proxy_route(
|
||||||
updated_url = base_url.copy_with(path=encoded_endpoint)
|
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||||
|
|
||||||
# Add or update query parameters
|
# Add or update query parameters
|
||||||
cohere_api_key = litellm.utils.get_secret(secret_name="COHERE_API_KEY")
|
cohere_api_key = passthrough_endpoint_router.get_credentials(
|
||||||
|
custom_llm_provider="cohere",
|
||||||
|
region_name=None,
|
||||||
|
)
|
||||||
|
|
||||||
## check for streaming
|
## check for streaming
|
||||||
is_streaming_request = False
|
is_streaming_request = False
|
||||||
|
@ -175,7 +183,10 @@ async def anthropic_proxy_route(
|
||||||
updated_url = base_url.copy_with(path=encoded_endpoint)
|
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||||
|
|
||||||
# Add or update query parameters
|
# Add or update query parameters
|
||||||
anthropic_api_key = litellm.utils.get_secret(secret_name="ANTHROPIC_API_KEY")
|
anthropic_api_key = passthrough_endpoint_router.get_credentials(
|
||||||
|
custom_llm_provider="anthropic",
|
||||||
|
region_name=None,
|
||||||
|
)
|
||||||
|
|
||||||
## check for streaming
|
## check for streaming
|
||||||
is_streaming_request = False
|
is_streaming_request = False
|
||||||
|
@ -297,18 +308,34 @@ def _is_bedrock_agent_runtime_route(endpoint: str) -> bool:
|
||||||
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
||||||
tags=["AssemblyAI Pass-through", "pass-through"],
|
tags=["AssemblyAI Pass-through", "pass-through"],
|
||||||
)
|
)
|
||||||
|
@router.api_route(
|
||||||
|
"/eu.assemblyai/{endpoint:path}",
|
||||||
|
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
||||||
|
tags=["AssemblyAI EU Pass-through", "pass-through"],
|
||||||
|
)
|
||||||
async def assemblyai_proxy_route(
|
async def assemblyai_proxy_route(
|
||||||
endpoint: str,
|
endpoint: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
fastapi_response: Response,
|
fastapi_response: Response,
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
):
|
):
|
||||||
|
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.assembly_passthrough_logging_handler import (
|
||||||
|
AssemblyAIPassthroughLoggingHandler,
|
||||||
|
)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
[Docs](https://api.assemblyai.com)
|
[Docs](https://api.assemblyai.com)
|
||||||
"""
|
"""
|
||||||
base_target_url = "https://api.assemblyai.com"
|
# Set base URL based on the route
|
||||||
|
assembly_region = AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
|
||||||
|
url=str(request.url)
|
||||||
|
)
|
||||||
|
base_target_url = (
|
||||||
|
AssemblyAIPassthroughLoggingHandler._get_assembly_base_url_from_region(
|
||||||
|
region=assembly_region
|
||||||
|
)
|
||||||
|
)
|
||||||
encoded_endpoint = httpx.URL(endpoint).path
|
encoded_endpoint = httpx.URL(endpoint).path
|
||||||
|
|
||||||
# Ensure endpoint starts with '/' for proper URL construction
|
# Ensure endpoint starts with '/' for proper URL construction
|
||||||
if not encoded_endpoint.startswith("/"):
|
if not encoded_endpoint.startswith("/"):
|
||||||
encoded_endpoint = "/" + encoded_endpoint
|
encoded_endpoint = "/" + encoded_endpoint
|
||||||
|
@ -318,7 +345,10 @@ async def assemblyai_proxy_route(
|
||||||
updated_url = base_url.copy_with(path=encoded_endpoint)
|
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||||
|
|
||||||
# Add or update query parameters
|
# Add or update query parameters
|
||||||
assemblyai_api_key = litellm.utils.get_secret(secret_name="ASSEMBLYAI_API_KEY")
|
assemblyai_api_key = passthrough_endpoint_router.get_credentials(
|
||||||
|
custom_llm_provider="assemblyai",
|
||||||
|
region_name=assembly_region,
|
||||||
|
)
|
||||||
|
|
||||||
## check for streaming
|
## check for streaming
|
||||||
is_streaming_request = False
|
is_streaming_request = False
|
||||||
|
@ -366,7 +396,10 @@ async def azure_proxy_route(
|
||||||
"Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure."
|
"Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure."
|
||||||
)
|
)
|
||||||
# Add or update query parameters
|
# Add or update query parameters
|
||||||
azure_api_key = get_secret_str(secret_name="AZURE_API_KEY")
|
azure_api_key = passthrough_endpoint_router.get_credentials(
|
||||||
|
custom_llm_provider="azure",
|
||||||
|
region_name=None,
|
||||||
|
)
|
||||||
if azure_api_key is None:
|
if azure_api_key is None:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Required 'AZURE_API_KEY' in environment to make pass-through calls to Azure."
|
"Required 'AZURE_API_KEY' in environment to make pass-through calls to Azure."
|
||||||
|
@ -400,7 +433,10 @@ async def openai_proxy_route(
|
||||||
"""
|
"""
|
||||||
base_target_url = "https://api.openai.com"
|
base_target_url = "https://api.openai.com"
|
||||||
# Add or update query parameters
|
# Add or update query parameters
|
||||||
openai_api_key = get_secret_str(secret_name="OPENAI_API_KEY")
|
openai_api_key = passthrough_endpoint_router.get_credentials(
|
||||||
|
custom_llm_provider="openai",
|
||||||
|
region_name=None,
|
||||||
|
)
|
||||||
if openai_api_key is None:
|
if openai_api_key is None:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Required 'OPENAI_API_KEY' in environment to make pass-through calls to OpenAI."
|
"Required 'OPENAI_API_KEY' in environment to make pass-through calls to OpenAI."
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, TypedDict
|
from typing import Literal, Optional, TypedDict
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
from litellm.litellm_core_utils.litellm_logging import (
|
from litellm.litellm_core_utils.litellm_logging import (
|
||||||
|
@ -18,7 +19,7 @@ from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggin
|
||||||
|
|
||||||
class AssemblyAITranscriptResponse(TypedDict, total=False):
|
class AssemblyAITranscriptResponse(TypedDict, total=False):
|
||||||
id: str
|
id: str
|
||||||
language_model: str
|
speech_model: str
|
||||||
acoustic_model: str
|
acoustic_model: str
|
||||||
language_code: str
|
language_code: str
|
||||||
status: str
|
status: str
|
||||||
|
@ -27,7 +28,8 @@ class AssemblyAITranscriptResponse(TypedDict, total=False):
|
||||||
|
|
||||||
class AssemblyAIPassthroughLoggingHandler:
|
class AssemblyAIPassthroughLoggingHandler:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.assembly_ai_base_url = "https://api.assemblyai.com/v2"
|
self.assembly_ai_base_url = "https://api.assemblyai.com"
|
||||||
|
self.assembly_ai_eu_base_url = "https://eu.assemblyai.com"
|
||||||
"""
|
"""
|
||||||
The base URL for the AssemblyAI API
|
The base URL for the AssemblyAI API
|
||||||
"""
|
"""
|
||||||
|
@ -43,8 +45,6 @@ class AssemblyAIPassthroughLoggingHandler:
|
||||||
The maximum number of polling attempts for the AssemblyAI API.
|
The maximum number of polling attempts for the AssemblyAI API.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.assemblyai_api_key = os.environ.get("ASSEMBLYAI_API_KEY")
|
|
||||||
|
|
||||||
def assemblyai_passthrough_logging_handler(
|
def assemblyai_passthrough_logging_handler(
|
||||||
self,
|
self,
|
||||||
httpx_response: httpx.Response,
|
httpx_response: httpx.Response,
|
||||||
|
@ -90,27 +90,34 @@ class AssemblyAIPassthroughLoggingHandler:
|
||||||
"""
|
"""
|
||||||
from ..pass_through_endpoints import pass_through_endpoint_logging
|
from ..pass_through_endpoints import pass_through_endpoint_logging
|
||||||
|
|
||||||
model = response_body.get("model", "")
|
model = response_body.get("speech_model", "")
|
||||||
verbose_proxy_logger.debug("response body", json.dumps(response_body, indent=4))
|
verbose_proxy_logger.debug(
|
||||||
|
"response body %s", json.dumps(response_body, indent=4)
|
||||||
|
)
|
||||||
kwargs["model"] = model
|
kwargs["model"] = model
|
||||||
kwargs["custom_llm_provider"] = "assemblyai"
|
kwargs["custom_llm_provider"] = "assemblyai"
|
||||||
|
logging_obj.model_call_details["model"] = model
|
||||||
|
logging_obj.model_call_details["custom_llm_provider"] = "assemblyai"
|
||||||
|
|
||||||
transcript_id = response_body.get("id")
|
transcript_id = response_body.get("id")
|
||||||
if transcript_id is None:
|
if transcript_id is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Transcript ID is required to log the cost of the transcription"
|
"Transcript ID is required to log the cost of the transcription"
|
||||||
)
|
)
|
||||||
transcript_response = self._poll_assembly_for_transcript_response(transcript_id)
|
transcript_response = self._poll_assembly_for_transcript_response(
|
||||||
|
transcript_id=transcript_id, url_route=url_route
|
||||||
|
)
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
"finished polling assembly for transcript response- got transcript response",
|
"finished polling assembly for transcript response- got transcript response %s",
|
||||||
json.dumps(transcript_response, indent=4),
|
json.dumps(transcript_response, indent=4),
|
||||||
)
|
)
|
||||||
if transcript_response:
|
if transcript_response:
|
||||||
cost = self.get_cost_for_assembly_transcript(transcript_response)
|
cost = self.get_cost_for_assembly_transcript(
|
||||||
|
speech_model=model,
|
||||||
|
transcript_response=transcript_response,
|
||||||
|
)
|
||||||
kwargs["response_cost"] = cost
|
kwargs["response_cost"] = cost
|
||||||
|
|
||||||
logging_obj.model_call_details["model"] = logging_obj.model
|
|
||||||
|
|
||||||
# Make standard logging object for Vertex AI
|
# Make standard logging object for Vertex AI
|
||||||
standard_logging_object = get_standard_logging_object_payload(
|
standard_logging_object = get_standard_logging_object_payload(
|
||||||
kwargs=kwargs,
|
kwargs=kwargs,
|
||||||
|
@ -157,7 +164,11 @@ class AssemblyAIPassthroughLoggingHandler:
|
||||||
return {}
|
return {}
|
||||||
return dict(transcript_response)
|
return dict(transcript_response)
|
||||||
|
|
||||||
def _get_assembly_transcript(self, transcript_id: str) -> Optional[dict]:
|
def _get_assembly_transcript(
|
||||||
|
self,
|
||||||
|
transcript_id: str,
|
||||||
|
request_region: Optional[Literal["eu"]] = None,
|
||||||
|
) -> Optional[dict]:
|
||||||
"""
|
"""
|
||||||
Get the transcript details from AssemblyAI API
|
Get the transcript details from AssemblyAI API
|
||||||
|
|
||||||
|
@ -167,10 +178,25 @@ class AssemblyAIPassthroughLoggingHandler:
|
||||||
Returns:
|
Returns:
|
||||||
Optional[dict]: Transcript details if successful, None otherwise
|
Optional[dict]: Transcript details if successful, None otherwise
|
||||||
"""
|
"""
|
||||||
|
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
|
||||||
|
passthrough_endpoint_router,
|
||||||
|
)
|
||||||
|
|
||||||
|
_base_url = (
|
||||||
|
self.assembly_ai_eu_base_url
|
||||||
|
if request_region == "eu"
|
||||||
|
else self.assembly_ai_base_url
|
||||||
|
)
|
||||||
|
_api_key = passthrough_endpoint_router.get_credentials(
|
||||||
|
custom_llm_provider="assemblyai",
|
||||||
|
region_name=request_region,
|
||||||
|
)
|
||||||
|
if _api_key is None:
|
||||||
|
raise ValueError("AssemblyAI API key not found")
|
||||||
try:
|
try:
|
||||||
url = f"{self.assembly_ai_base_url}/transcript/{transcript_id}"
|
url = f"{_base_url}/v2/transcript/{transcript_id}"
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {self.assemblyai_api_key}",
|
"Authorization": f"Bearer {_api_key}",
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,11 +205,15 @@ class AssemblyAIPassthroughLoggingHandler:
|
||||||
|
|
||||||
return response.json()
|
return response.json()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.debug(f"Error getting AssemblyAI transcript: {str(e)}")
|
verbose_proxy_logger.exception(
|
||||||
|
f"[Non blocking logging error] Error getting AssemblyAI transcript: {str(e)}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _poll_assembly_for_transcript_response(
|
def _poll_assembly_for_transcript_response(
|
||||||
self, transcript_id: str
|
self,
|
||||||
|
transcript_id: str,
|
||||||
|
url_route: Optional[str] = None,
|
||||||
) -> Optional[AssemblyAITranscriptResponse]:
|
) -> Optional[AssemblyAITranscriptResponse]:
|
||||||
"""
|
"""
|
||||||
Poll the status of the transcript until it is completed or timeout (30 minutes)
|
Poll the status of the transcript until it is completed or timeout (30 minutes)
|
||||||
|
@ -191,7 +221,12 @@ class AssemblyAIPassthroughLoggingHandler:
|
||||||
for _ in range(
|
for _ in range(
|
||||||
self.max_polling_attempts
|
self.max_polling_attempts
|
||||||
): # 180 attempts * 10s = 30 minutes max
|
): # 180 attempts * 10s = 30 minutes max
|
||||||
transcript = self._get_assembly_transcript(transcript_id)
|
transcript = self._get_assembly_transcript(
|
||||||
|
request_region=AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
|
||||||
|
url=url_route
|
||||||
|
),
|
||||||
|
transcript_id=transcript_id,
|
||||||
|
)
|
||||||
if transcript is None:
|
if transcript is None:
|
||||||
return None
|
return None
|
||||||
if (
|
if (
|
||||||
|
@ -205,6 +240,7 @@ class AssemblyAIPassthroughLoggingHandler:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_cost_for_assembly_transcript(
|
def get_cost_for_assembly_transcript(
|
||||||
transcript_response: AssemblyAITranscriptResponse,
|
transcript_response: AssemblyAITranscriptResponse,
|
||||||
|
speech_model: str,
|
||||||
) -> Optional[float]:
|
) -> Optional[float]:
|
||||||
"""
|
"""
|
||||||
Get the cost for the assembly transcript
|
Get the cost for the assembly transcript
|
||||||
|
@ -212,7 +248,50 @@ class AssemblyAIPassthroughLoggingHandler:
|
||||||
_audio_duration = transcript_response.get("audio_duration")
|
_audio_duration = transcript_response.get("audio_duration")
|
||||||
if _audio_duration is None:
|
if _audio_duration is None:
|
||||||
return None
|
return None
|
||||||
return _audio_duration * 0.0001
|
_cost_per_second = (
|
||||||
|
AssemblyAIPassthroughLoggingHandler.get_cost_per_second_for_assembly_model(
|
||||||
|
speech_model=speech_model
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if _cost_per_second is None:
|
||||||
|
return None
|
||||||
|
return _audio_duration * _cost_per_second
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_cost_per_second_for_assembly_model(speech_model: str) -> Optional[float]:
|
||||||
|
"""
|
||||||
|
Get the cost per second for the assembly model.
|
||||||
|
Falls back to assemblyai/nano if the specific speech model info cannot be found.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# First try with the provided speech model
|
||||||
|
try:
|
||||||
|
model_info = litellm.get_model_info(
|
||||||
|
model=speech_model,
|
||||||
|
custom_llm_provider="assemblyai",
|
||||||
|
)
|
||||||
|
if model_info and model_info.get("input_cost_per_second") is not None:
|
||||||
|
return model_info.get("input_cost_per_second")
|
||||||
|
except Exception:
|
||||||
|
pass # Continue to fallback if model not found
|
||||||
|
|
||||||
|
# Fallback to assemblyai/nano if speech model info not found
|
||||||
|
try:
|
||||||
|
model_info = litellm.get_model_info(
|
||||||
|
model="assemblyai/nano",
|
||||||
|
custom_llm_provider="assemblyai",
|
||||||
|
)
|
||||||
|
if model_info and model_info.get("input_cost_per_second") is not None:
|
||||||
|
return model_info.get("input_cost_per_second")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.exception(
|
||||||
|
f"[Non blocking logging error] Error getting AssemblyAI model info: {str(e)}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _should_log_request(request_method: str) -> bool:
|
def _should_log_request(request_method: str) -> bool:
|
||||||
|
@ -220,3 +299,25 @@ class AssemblyAIPassthroughLoggingHandler:
|
||||||
only POST transcription jobs are logged. litellm will POLL assembly to wait for the transcription to complete to log the complete response / cost
|
only POST transcription jobs are logged. litellm will POLL assembly to wait for the transcription to complete to log the complete response / cost
|
||||||
"""
|
"""
|
||||||
return request_method == "POST"
|
return request_method == "POST"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_assembly_region_from_url(url: Optional[str]) -> Optional[Literal["eu"]]:
|
||||||
|
"""
|
||||||
|
Get the region from the URL
|
||||||
|
"""
|
||||||
|
if url is None:
|
||||||
|
return None
|
||||||
|
if urlparse(url).hostname == "eu.assemblyai.com":
|
||||||
|
return "eu"
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_assembly_base_url_from_region(region: Optional[Literal["eu"]]) -> str:
|
||||||
|
"""
|
||||||
|
Get the base URL for the AssemblyAI API
|
||||||
|
if region == "eu", return "https://api.eu.assemblyai.com"
|
||||||
|
else return "https://api.assemblyai.com"
|
||||||
|
"""
|
||||||
|
if region == "eu":
|
||||||
|
return "https://api.eu.assemblyai.com"
|
||||||
|
return "https://api.assemblyai.com"
|
||||||
|
|
|
@ -0,0 +1,93 @@
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
|
from litellm.secret_managers.main import get_secret_str
|
||||||
|
|
||||||
|
|
||||||
|
class PassthroughEndpointRouter:
|
||||||
|
"""
|
||||||
|
Use this class to Set/Get credentials for pass-through endpoints
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.credentials: Dict[str, str] = {}
|
||||||
|
|
||||||
|
def set_pass_through_credentials(
|
||||||
|
self,
|
||||||
|
custom_llm_provider: str,
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_key: Optional[str],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Set credentials for a pass-through endpoint. Used when a user adds a pass-through LLM endpoint on the UI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
custom_llm_provider: The provider of the pass-through endpoint
|
||||||
|
api_base: The base URL of the pass-through endpoint
|
||||||
|
api_key: The API key for the pass-through endpoint
|
||||||
|
"""
|
||||||
|
credential_name = self._get_credential_name_for_provider(
|
||||||
|
custom_llm_provider=custom_llm_provider,
|
||||||
|
region_name=self._get_region_name_from_api_base(
|
||||||
|
api_base=api_base, custom_llm_provider=custom_llm_provider
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if api_key is None:
|
||||||
|
raise ValueError("api_key is required for setting pass-through credentials")
|
||||||
|
self.credentials[credential_name] = api_key
|
||||||
|
|
||||||
|
def get_credentials(
|
||||||
|
self,
|
||||||
|
custom_llm_provider: str,
|
||||||
|
region_name: Optional[str],
|
||||||
|
) -> Optional[str]:
|
||||||
|
credential_name = self._get_credential_name_for_provider(
|
||||||
|
custom_llm_provider=custom_llm_provider,
|
||||||
|
region_name=region_name,
|
||||||
|
)
|
||||||
|
verbose_logger.debug(
|
||||||
|
f"Pass-through llm endpoints router, looking for credentials for {credential_name}"
|
||||||
|
)
|
||||||
|
if credential_name in self.credentials:
|
||||||
|
verbose_logger.debug(f"Found credentials for {credential_name}")
|
||||||
|
return self.credentials[credential_name]
|
||||||
|
else:
|
||||||
|
verbose_logger.debug(
|
||||||
|
f"No credentials found for {credential_name}, looking for env variable"
|
||||||
|
)
|
||||||
|
_env_variable_name = (
|
||||||
|
self._get_default_env_variable_name_passthrough_endpoint(
|
||||||
|
custom_llm_provider=custom_llm_provider,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return get_secret_str(_env_variable_name)
|
||||||
|
|
||||||
|
def _get_credential_name_for_provider(
|
||||||
|
self,
|
||||||
|
custom_llm_provider: str,
|
||||||
|
region_name: Optional[str],
|
||||||
|
) -> str:
|
||||||
|
if region_name is None:
|
||||||
|
return f"{custom_llm_provider.upper()}_API_KEY"
|
||||||
|
return f"{custom_llm_provider.upper()}_{region_name.upper()}_API_KEY"
|
||||||
|
|
||||||
|
def _get_region_name_from_api_base(
|
||||||
|
self,
|
||||||
|
custom_llm_provider: str,
|
||||||
|
api_base: Optional[str],
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get the region name from the API base.
|
||||||
|
|
||||||
|
Each provider might have a different way of specifying the region in the API base - this is where you can use conditional logic to handle that.
|
||||||
|
"""
|
||||||
|
if custom_llm_provider == "assemblyai":
|
||||||
|
if api_base and "eu" in api_base:
|
||||||
|
return "eu"
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_default_env_variable_name_passthrough_endpoint(
|
||||||
|
custom_llm_provider: str,
|
||||||
|
) -> str:
|
||||||
|
return f"{custom_llm_provider.upper()}_API_KEY"
|
|
@ -4180,8 +4180,14 @@ class Router:
|
||||||
vertex_credentials=deployment.litellm_params.vertex_credentials,
|
vertex_credentials=deployment.litellm_params.vertex_credentials,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
verbose_router_logger.error(
|
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
|
||||||
f"Unsupported provider - {custom_llm_provider} for pass-through endpoints"
|
passthrough_endpoint_router,
|
||||||
|
)
|
||||||
|
|
||||||
|
passthrough_endpoint_router.set_pass_through_credentials(
|
||||||
|
custom_llm_provider=custom_llm_provider,
|
||||||
|
api_base=deployment.litellm_params.api_base,
|
||||||
|
api_key=deployment.litellm_params.api_key,
|
||||||
)
|
)
|
||||||
pass
|
pass
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -1870,6 +1870,7 @@ class LlmProviders(str, Enum):
|
||||||
LANGFUSE = "langfuse"
|
LANGFUSE = "langfuse"
|
||||||
HUMANLOOP = "humanloop"
|
HUMANLOOP = "humanloop"
|
||||||
TOPAZ = "topaz"
|
TOPAZ = "topaz"
|
||||||
|
ASSEMBLYAI = "assemblyai"
|
||||||
|
|
||||||
|
|
||||||
# Create a set of all provider values for quick lookup
|
# Create a set of all provider values for quick lookup
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""
|
"""
|
||||||
This test ensures that the proxy can passthrough anthropic requests
|
This test ensures that the proxy can passthrough requests to assemblyai
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
|
@ -41,7 +41,6 @@ from litellm.proxy.pass_through_endpoints.success_handler import (
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def assembly_handler():
|
def assembly_handler():
|
||||||
handler = AssemblyAIPassthroughLoggingHandler()
|
handler = AssemblyAIPassthroughLoggingHandler()
|
||||||
handler.assemblyai_api_key = "test-key"
|
|
||||||
return handler
|
return handler
|
||||||
|
|
||||||
|
|
||||||
|
@ -66,7 +65,13 @@ def test_should_log_request():
|
||||||
def test_get_assembly_transcript(assembly_handler, mock_transcript_response):
|
def test_get_assembly_transcript(assembly_handler, mock_transcript_response):
|
||||||
"""
|
"""
|
||||||
Test that the _get_assembly_transcript method calls GET /v2/transcript/{transcript_id}
|
Test that the _get_assembly_transcript method calls GET /v2/transcript/{transcript_id}
|
||||||
|
and uses the test key returned by the mocked get_credentials.
|
||||||
"""
|
"""
|
||||||
|
# Patch get_credentials to return "test-key"
|
||||||
|
with patch(
|
||||||
|
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router.get_credentials",
|
||||||
|
return_value="test-key",
|
||||||
|
):
|
||||||
with patch("httpx.get") as mock_get:
|
with patch("httpx.get") as mock_get:
|
||||||
mock_get.return_value.json.return_value = mock_transcript_response
|
mock_get.return_value.json.return_value = mock_transcript_response
|
||||||
mock_get.return_value.raise_for_status.return_value = None
|
mock_get.return_value.raise_for_status.return_value = None
|
||||||
|
@ -89,6 +94,10 @@ def test_poll_assembly_for_transcript_response(
|
||||||
"""
|
"""
|
||||||
Test that the _poll_assembly_for_transcript_response method returns the correct transcript response
|
Test that the _poll_assembly_for_transcript_response method returns the correct transcript response
|
||||||
"""
|
"""
|
||||||
|
with patch(
|
||||||
|
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router.get_credentials",
|
||||||
|
return_value="test-key",
|
||||||
|
):
|
||||||
with patch("httpx.get") as mock_get:
|
with patch("httpx.get") as mock_get:
|
||||||
mock_get.return_value.json.return_value = mock_transcript_response
|
mock_get.return_value.json.return_value = mock_transcript_response
|
||||||
mock_get.return_value.raise_for_status.return_value = None
|
mock_get.return_value.raise_for_status.return_value = None
|
||||||
|
@ -98,9 +107,11 @@ def test_poll_assembly_for_transcript_response(
|
||||||
assembly_handler.max_polling_attempts = 2
|
assembly_handler.max_polling_attempts = 2
|
||||||
|
|
||||||
transcript = assembly_handler._poll_assembly_for_transcript_response(
|
transcript = assembly_handler._poll_assembly_for_transcript_response(
|
||||||
"test-transcript-id"
|
"test-transcript-id",
|
||||||
|
)
|
||||||
|
assert transcript == AssemblyAITranscriptResponse(
|
||||||
|
**mock_transcript_response
|
||||||
)
|
)
|
||||||
assert transcript == AssemblyAITranscriptResponse(**mock_transcript_response)
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_assemblyai_route():
|
def test_is_assemblyai_route():
|
||||||
|
|
|
@ -0,0 +1,134 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch, MagicMock
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.abspath("../..")) #
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
|
||||||
|
PassthroughEndpointRouter,
|
||||||
|
)
|
||||||
|
|
||||||
|
passthrough_endpoint_router = PassthroughEndpointRouter()
|
||||||
|
|
||||||
|
"""
|
||||||
|
1. Basic Usage
|
||||||
|
- Set OpenAI, AssemblyAI, Anthropic, Cohere credentials
|
||||||
|
- GET credentials from passthrough_endpoint_router
|
||||||
|
|
||||||
|
2. Basic Usage - when not using DB
|
||||||
|
- No credentials set
|
||||||
|
- call GET credentials with provider name, assert that it reads the secret from the environment variable
|
||||||
|
|
||||||
|
|
||||||
|
3. Unit test for _get_default_env_variable_name_passthrough_endpoint
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TestPassthroughEndpointRouter(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.router = PassthroughEndpointRouter()
|
||||||
|
|
||||||
|
def test_set_and_get_credentials(self):
|
||||||
|
"""
|
||||||
|
1. Basic Usage:
|
||||||
|
- Set credentials for OpenAI, AssemblyAI, Anthropic, Cohere
|
||||||
|
- GET credentials from passthrough_endpoint_router (from the memory store when available)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# OpenAI: standard (no region-specific logic)
|
||||||
|
self.router.set_pass_through_credentials("openai", None, "openai_key")
|
||||||
|
self.assertEqual(self.router.get_credentials("openai", None), "openai_key")
|
||||||
|
|
||||||
|
# AssemblyAI: using an API base that contains 'eu' should trigger regional logic.
|
||||||
|
api_base_eu = "https://api.eu.assemblyai.com"
|
||||||
|
self.router.set_pass_through_credentials(
|
||||||
|
"assemblyai", api_base_eu, "assemblyai_key"
|
||||||
|
)
|
||||||
|
# When calling get_credentials, pass the region "eu" (extracted from the API base)
|
||||||
|
self.assertEqual(
|
||||||
|
self.router.get_credentials("assemblyai", "eu"), "assemblyai_key"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Anthropic: no region set
|
||||||
|
self.router.set_pass_through_credentials("anthropic", None, "anthropic_key")
|
||||||
|
self.assertEqual(
|
||||||
|
self.router.get_credentials("anthropic", None), "anthropic_key"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cohere: no region set
|
||||||
|
self.router.set_pass_through_credentials("cohere", None, "cohere_key")
|
||||||
|
self.assertEqual(self.router.get_credentials("cohere", None), "cohere_key")
|
||||||
|
|
||||||
|
def test_get_credentials_from_env(self):
|
||||||
|
"""
|
||||||
|
2. Basic Usage - when not using the database:
|
||||||
|
- No credentials set in memory
|
||||||
|
- Call get_credentials with provider name and expect it to read from the environment variable (via get_secret_str)
|
||||||
|
"""
|
||||||
|
# Patch the get_secret_str function within the router's module.
|
||||||
|
with patch(
|
||||||
|
"litellm.proxy.pass_through_endpoints.passthrough_endpoint_router.get_secret_str"
|
||||||
|
) as mock_get_secret:
|
||||||
|
mock_get_secret.return_value = "env_openai_key"
|
||||||
|
# For "openai", if credentials are not set, it should fallback to the env variable.
|
||||||
|
result = self.router.get_credentials("openai", None)
|
||||||
|
self.assertEqual(result, "env_openai_key")
|
||||||
|
mock_get_secret.assert_called_once_with("OPENAI_API_KEY")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.proxy.pass_through_endpoints.passthrough_endpoint_router.get_secret_str"
|
||||||
|
) as mock_get_secret:
|
||||||
|
mock_get_secret.return_value = "env_cohere_key"
|
||||||
|
result = self.router.get_credentials("cohere", None)
|
||||||
|
self.assertEqual(result, "env_cohere_key")
|
||||||
|
mock_get_secret.assert_called_once_with("COHERE_API_KEY")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.proxy.pass_through_endpoints.passthrough_endpoint_router.get_secret_str"
|
||||||
|
) as mock_get_secret:
|
||||||
|
mock_get_secret.return_value = "env_anthropic_key"
|
||||||
|
result = self.router.get_credentials("anthropic", None)
|
||||||
|
self.assertEqual(result, "env_anthropic_key")
|
||||||
|
mock_get_secret.assert_called_once_with("ANTHROPIC_API_KEY")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.proxy.pass_through_endpoints.passthrough_endpoint_router.get_secret_str"
|
||||||
|
) as mock_get_secret:
|
||||||
|
mock_get_secret.return_value = "env_azure_key"
|
||||||
|
result = self.router.get_credentials("azure", None)
|
||||||
|
self.assertEqual(result, "env_azure_key")
|
||||||
|
mock_get_secret.assert_called_once_with("AZURE_API_KEY")
|
||||||
|
|
||||||
|
def test_default_env_variable_method(self):
|
||||||
|
"""
|
||||||
|
3. Unit test for _get_default_env_variable_name_passthrough_endpoint:
|
||||||
|
- Should return the provider in uppercase followed by _API_KEY.
|
||||||
|
"""
|
||||||
|
self.assertEqual(
|
||||||
|
PassthroughEndpointRouter._get_default_env_variable_name_passthrough_endpoint(
|
||||||
|
"openai"
|
||||||
|
),
|
||||||
|
"OPENAI_API_KEY",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
PassthroughEndpointRouter._get_default_env_variable_name_passthrough_endpoint(
|
||||||
|
"assemblyai"
|
||||||
|
),
|
||||||
|
"ASSEMBLYAI_API_KEY",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
PassthroughEndpointRouter._get_default_env_variable_name_passthrough_endpoint(
|
||||||
|
"anthropic"
|
||||||
|
),
|
||||||
|
"ANTHROPIC_API_KEY",
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
PassthroughEndpointRouter._get_default_env_variable_name_passthrough_endpoint(
|
||||||
|
"cohere"
|
||||||
|
),
|
||||||
|
"COHERE_API_KEY",
|
||||||
|
)
|
|
@ -76,27 +76,6 @@ def test_initialize_deployment_for_pass_through_missing_params():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_initialize_deployment_for_pass_through_unsupported_provider():
|
|
||||||
"""
|
|
||||||
Test initialization with an unsupported provider
|
|
||||||
"""
|
|
||||||
router = Router(model_list=[])
|
|
||||||
deployment = Deployment(
|
|
||||||
model_name="unsupported-test",
|
|
||||||
litellm_params=LiteLLM_Params(
|
|
||||||
model="unsupported/test-model",
|
|
||||||
use_in_pass_through=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should not raise an error, but log a warning
|
|
||||||
router._initialize_deployment_for_pass_through(
|
|
||||||
deployment=deployment,
|
|
||||||
custom_llm_provider="unsupported_provider",
|
|
||||||
model="unsupported/test-model",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_initialize_deployment_when_pass_through_disabled():
|
def test_initialize_deployment_when_pass_through_disabled():
|
||||||
"""
|
"""
|
||||||
Test that initialization simply exits when use_in_pass_through is False
|
Test that initialization simply exits when use_in_pass_through is False
|
||||||
|
|
202
tests/store_model_in_db_tests/test_adding_passthrough_model.py
Normal file
202
tests/store_model_in_db_tests/test_adding_passthrough_model.py
Normal file
|
@ -0,0 +1,202 @@
|
||||||
|
"""
|
||||||
|
Test adding a pass through assemblyai model + api key + api base to the db
|
||||||
|
wait 20 seconds
|
||||||
|
make request
|
||||||
|
|
||||||
|
Cases to cover
|
||||||
|
1. user points api base to <proxy-base>/assemblyai
|
||||||
|
2. user points api base to <proxy-base>/asssemblyai/us
|
||||||
|
3. user points api base to <proxy-base>/assemblyai/eu
|
||||||
|
4. Bad API Key / credential - 401
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import assemblyai as aai
|
||||||
|
import pytest
|
||||||
|
import httpx
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
TEST_MASTER_KEY = "sk-1234"
|
||||||
|
PROXY_BASE_URL = "http://0.0.0.0:4000"
|
||||||
|
US_BASE_URL = f"{PROXY_BASE_URL}/assemblyai"
|
||||||
|
EU_BASE_URL = f"{PROXY_BASE_URL}/eu.assemblyai"
|
||||||
|
ASSEMBLYAI_API_KEY_ENV_VAR = "TEST_SPECIAL_ASSEMBLYAI_API_KEY"
|
||||||
|
|
||||||
|
|
||||||
|
def _delete_all_assemblyai_models_from_db():
|
||||||
|
"""
|
||||||
|
Delete all assemblyai models from the db
|
||||||
|
"""
|
||||||
|
print("Deleting all assemblyai models from the db.......")
|
||||||
|
model_list_response = httpx.get(
|
||||||
|
url=f"{PROXY_BASE_URL}/v2/model/info",
|
||||||
|
headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
|
||||||
|
)
|
||||||
|
response_data = model_list_response.json()
|
||||||
|
print("model list response", json.dumps(response_data, indent=4, default=str))
|
||||||
|
# Filter for only AssemblyAI models
|
||||||
|
assemblyai_models = [
|
||||||
|
model
|
||||||
|
for model in response_data["data"]
|
||||||
|
if model.get("litellm_params", {}).get("custom_llm_provider") == "assemblyai"
|
||||||
|
]
|
||||||
|
|
||||||
|
for model in assemblyai_models:
|
||||||
|
model_id = model["model_info"]["id"]
|
||||||
|
httpx.post(
|
||||||
|
url=f"{PROXY_BASE_URL}/model/delete",
|
||||||
|
headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
|
||||||
|
json={"id": model_id},
|
||||||
|
)
|
||||||
|
print("Deleted all assemblyai models from the db")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def cleanup_assemblyai_models():
|
||||||
|
"""
|
||||||
|
Fixture to clean up AssemblyAI models before and after each test
|
||||||
|
"""
|
||||||
|
# Clean up before test
|
||||||
|
_delete_all_assemblyai_models_from_db()
|
||||||
|
|
||||||
|
# Run the test
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Clean up after test
|
||||||
|
_delete_all_assemblyai_models_from_db()
|
||||||
|
|
||||||
|
|
||||||
|
def test_e2e_assemblyai_passthrough():
|
||||||
|
"""
|
||||||
|
Test adding a pass through assemblyai model + api key + api base to the db
|
||||||
|
wait 20 seconds
|
||||||
|
make request
|
||||||
|
"""
|
||||||
|
add_assembly_ai_model_to_db(api_base="https://api.assemblyai.com")
|
||||||
|
virtual_key = create_virtual_key()
|
||||||
|
# make request
|
||||||
|
make_assemblyai_basic_transcribe_request(
|
||||||
|
virtual_key=virtual_key, assemblyai_base_url=US_BASE_URL
|
||||||
|
)
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_e2e_assemblyai_passthrough_eu():
|
||||||
|
"""
|
||||||
|
Test adding a pass through assemblyai model + api key + api base to the db
|
||||||
|
wait 20 seconds
|
||||||
|
make request
|
||||||
|
"""
|
||||||
|
add_assembly_ai_model_to_db(api_base="https://api.eu.assemblyai.com")
|
||||||
|
virtual_key = create_virtual_key()
|
||||||
|
# make request
|
||||||
|
make_assemblyai_basic_transcribe_request(
|
||||||
|
virtual_key=virtual_key, assemblyai_base_url=EU_BASE_URL
|
||||||
|
)
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_assemblyai_routes_with_bad_api_key():
|
||||||
|
"""
|
||||||
|
Test AssemblyAI endpoints with invalid API key to ensure proper error handling
|
||||||
|
"""
|
||||||
|
bad_api_key = "sk-12222"
|
||||||
|
payload = {
|
||||||
|
"audio_url": "https://assembly.ai/wildfires.mp3",
|
||||||
|
"audio_end_at": 280,
|
||||||
|
"audio_start_from": 10,
|
||||||
|
"auto_chapters": True,
|
||||||
|
}
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {bad_api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test EU endpoint
|
||||||
|
eu_response = httpx.post(
|
||||||
|
f"{PROXY_BASE_URL}/eu.assemblyai/v2/transcript", headers=headers, json=payload
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
eu_response.status_code == 401
|
||||||
|
), f"Expected 401 unauthorized, got {eu_response.status_code}"
|
||||||
|
|
||||||
|
# Test US endpoint
|
||||||
|
us_response = httpx.post(
|
||||||
|
f"{PROXY_BASE_URL}/assemblyai/v2/transcript", headers=headers, json=payload
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
us_response.status_code == 401
|
||||||
|
), f"Expected 401 unauthorized, got {us_response.status_code}"
|
||||||
|
|
||||||
|
|
||||||
|
def create_virtual_key():
|
||||||
|
"""
|
||||||
|
Create a virtual key
|
||||||
|
"""
|
||||||
|
response = httpx.post(
|
||||||
|
url=f"{PROXY_BASE_URL}/key/generate",
|
||||||
|
headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
|
||||||
|
json={},
|
||||||
|
)
|
||||||
|
print(response.json())
|
||||||
|
return response.json()["token"]
|
||||||
|
|
||||||
|
|
||||||
|
def add_assembly_ai_model_to_db(
|
||||||
|
api_base: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add the assemblyai model to the db - makes a http request to the /model/new endpoint on PROXY_BASE_URL
|
||||||
|
"""
|
||||||
|
print("assmbly ai api key", os.getenv(ASSEMBLYAI_API_KEY_ENV_VAR))
|
||||||
|
response = httpx.post(
|
||||||
|
url=f"{PROXY_BASE_URL}/model/new",
|
||||||
|
headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
|
||||||
|
json={
|
||||||
|
"model_name": "assemblyai/*",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "assemblyai/*",
|
||||||
|
"custom_llm_provider": "assemblyai",
|
||||||
|
"api_key": os.getenv(ASSEMBLYAI_API_KEY_ENV_VAR),
|
||||||
|
"api_base": api_base,
|
||||||
|
"use_in_pass_through": True,
|
||||||
|
},
|
||||||
|
"model_info": {},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
print(response.json())
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def make_assemblyai_basic_transcribe_request(
|
||||||
|
virtual_key: str, assemblyai_base_url: str
|
||||||
|
):
|
||||||
|
print("making basic transcribe request to assemblyai passthrough")
|
||||||
|
|
||||||
|
# Replace with your API key
|
||||||
|
aai.settings.api_key = f"Bearer {virtual_key}"
|
||||||
|
aai.settings.base_url = assemblyai_base_url
|
||||||
|
|
||||||
|
# URL of the file to transcribe
|
||||||
|
FILE_URL = "https://assembly.ai/wildfires.mp3"
|
||||||
|
|
||||||
|
# You can also transcribe a local file by passing in a file path
|
||||||
|
# FILE_URL = './path/to/file.mp3'
|
||||||
|
|
||||||
|
transcriber = aai.Transcriber()
|
||||||
|
transcript = transcriber.transcribe(FILE_URL)
|
||||||
|
print(transcript)
|
||||||
|
print(transcript.id)
|
||||||
|
if transcript.id:
|
||||||
|
transcript.delete_by_id(transcript.id)
|
||||||
|
else:
|
||||||
|
pytest.fail("Failed to get transcript id")
|
||||||
|
|
||||||
|
if transcript.status == aai.TranscriptStatus.error:
|
||||||
|
print(transcript.error)
|
||||||
|
pytest.fail(f"Failed to transcribe file error: {transcript.error}")
|
||||||
|
else:
|
||||||
|
print(transcript.text)
|
|
@ -1,5 +1,5 @@
|
||||||
import React from "react";
|
import React from "react";
|
||||||
import { Form } from "antd";
|
import { Form, Select } from "antd";
|
||||||
import { TextInput, Text } from "@tremor/react";
|
import { TextInput, Text } from "@tremor/react";
|
||||||
import { Row, Col, Typography, Button as Button2, Upload, UploadProps } from "antd";
|
import { Row, Col, Typography, Button as Button2, Upload, UploadProps } from "antd";
|
||||||
import { UploadOutlined } from "@ant-design/icons";
|
import { UploadOutlined } from "@ant-design/icons";
|
||||||
|
@ -72,9 +72,21 @@ const ProviderSpecificFields: React.FC<ProviderSpecificFieldsProps> = ({
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{selectedProviderEnum === Providers.AssemblyAI && (
|
||||||
|
<Form.Item
|
||||||
|
rules={[{ required: true, message: "Required" }]}
|
||||||
|
label="API Base"
|
||||||
|
name="api_base"
|
||||||
|
>
|
||||||
|
<Select placeholder="Select API Base">
|
||||||
|
<Select.Option value="https://api.assemblyai.com">https://api.assemblyai.com</Select.Option>
|
||||||
|
<Select.Option value="https://api.eu.assemblyai.com">https://api.eu.assemblyai.com</Select.Option>
|
||||||
|
</Select>
|
||||||
|
</Form.Item>
|
||||||
|
)}
|
||||||
|
|
||||||
{(selectedProviderEnum === Providers.Azure ||
|
{(selectedProviderEnum === Providers.Azure ||
|
||||||
selectedProviderEnum === Providers.OpenAI_Compatible ||
|
selectedProviderEnum === Providers.OpenAI_Compatible
|
||||||
selectedProviderEnum === Providers.AssemblyAI
|
|
||||||
) && (
|
) && (
|
||||||
<Form.Item
|
<Form.Item
|
||||||
rules={[{ required: true, message: "Required" }]}
|
rules={[{ required: true, message: "Required" }]}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue