mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
(QA+UI) - e2e flow for adding assembly ai passthrough endpoints (#8337)
* add initial test for assembly ai * start using PassthroughEndpointRouter * migrate to llm passthrough endpoints * add assembly ai as a known provider * fix PassthroughEndpointRouter * fix set_pass_through_credentials * working EU request to assembly ai pass through endpoint * add e2e test assembly * test_assemblyai_routes_with_bad_api_key * clean up pass through endpoint router * e2e testing for assembly ai pass through * test assembly ai e2e testing * delete assembly ai models * fix code quality * ui working assembly ai api base flow * fix install assembly ai * update model call details with kwargs for pass through logging * fix tracking assembly ai model in response * _handle_assemblyai_passthrough_logging * fix test_initialize_deployment_for_pass_through_unsupported_provider * TestPassthroughEndpointRouter * _get_assembly_transcript * fix assembly ai pt logging tests * fix assemblyai_proxy_route * fix _get_assembly_region_from_url
This commit is contained in:
parent
5dcb87a88b
commit
65c91cbbbc
13 changed files with 656 additions and 79 deletions
|
@ -1552,6 +1552,7 @@ jobs:
|
|||
pip install "pytest-retry==1.6.3"
|
||||
pip install "pytest-mock==3.12.0"
|
||||
pip install "pytest-asyncio==0.21.1"
|
||||
pip install "assemblyai==0.37.0"
|
||||
- run:
|
||||
name: Build Docker image
|
||||
command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
|
||||
|
|
|
@ -248,6 +248,7 @@ class LiteLLMRoutes(enum.Enum):
|
|||
"/azure",
|
||||
"/openai",
|
||||
"/assemblyai",
|
||||
"/eu.assemblyai",
|
||||
]
|
||||
|
||||
anthropic_routes = [
|
||||
|
|
|
@ -20,9 +20,13 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
|
|||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
from .passthrough_endpoint_router import PassthroughEndpointRouter
|
||||
|
||||
router = APIRouter()
|
||||
default_vertex_config = None
|
||||
|
||||
passthrough_endpoint_router = PassthroughEndpointRouter()
|
||||
|
||||
|
||||
def create_request_copy(request: Request):
|
||||
return {
|
||||
|
@ -68,8 +72,9 @@ async def gemini_proxy_route(
|
|||
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||
|
||||
# Add or update query parameters
|
||||
gemini_api_key: Optional[str] = litellm.utils.get_secret( # type: ignore
|
||||
secret_name="GEMINI_API_KEY"
|
||||
gemini_api_key: Optional[str] = passthrough_endpoint_router.get_credentials(
|
||||
custom_llm_provider="gemini",
|
||||
region_name=None,
|
||||
)
|
||||
if gemini_api_key is None:
|
||||
raise Exception(
|
||||
|
@ -126,7 +131,10 @@ async def cohere_proxy_route(
|
|||
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||
|
||||
# Add or update query parameters
|
||||
cohere_api_key = litellm.utils.get_secret(secret_name="COHERE_API_KEY")
|
||||
cohere_api_key = passthrough_endpoint_router.get_credentials(
|
||||
custom_llm_provider="cohere",
|
||||
region_name=None,
|
||||
)
|
||||
|
||||
## check for streaming
|
||||
is_streaming_request = False
|
||||
|
@ -175,7 +183,10 @@ async def anthropic_proxy_route(
|
|||
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||
|
||||
# Add or update query parameters
|
||||
anthropic_api_key = litellm.utils.get_secret(secret_name="ANTHROPIC_API_KEY")
|
||||
anthropic_api_key = passthrough_endpoint_router.get_credentials(
|
||||
custom_llm_provider="anthropic",
|
||||
region_name=None,
|
||||
)
|
||||
|
||||
## check for streaming
|
||||
is_streaming_request = False
|
||||
|
@ -297,18 +308,34 @@ def _is_bedrock_agent_runtime_route(endpoint: str) -> bool:
|
|||
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
||||
tags=["AssemblyAI Pass-through", "pass-through"],
|
||||
)
|
||||
@router.api_route(
|
||||
"/eu.assemblyai/{endpoint:path}",
|
||||
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
||||
tags=["AssemblyAI EU Pass-through", "pass-through"],
|
||||
)
|
||||
async def assemblyai_proxy_route(
|
||||
endpoint: str,
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.assembly_passthrough_logging_handler import (
|
||||
AssemblyAIPassthroughLoggingHandler,
|
||||
)
|
||||
|
||||
"""
|
||||
[Docs](https://api.assemblyai.com)
|
||||
"""
|
||||
base_target_url = "https://api.assemblyai.com"
|
||||
# Set base URL based on the route
|
||||
assembly_region = AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
|
||||
url=str(request.url)
|
||||
)
|
||||
base_target_url = (
|
||||
AssemblyAIPassthroughLoggingHandler._get_assembly_base_url_from_region(
|
||||
region=assembly_region
|
||||
)
|
||||
)
|
||||
encoded_endpoint = httpx.URL(endpoint).path
|
||||
|
||||
# Ensure endpoint starts with '/' for proper URL construction
|
||||
if not encoded_endpoint.startswith("/"):
|
||||
encoded_endpoint = "/" + encoded_endpoint
|
||||
|
@ -318,7 +345,10 @@ async def assemblyai_proxy_route(
|
|||
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||
|
||||
# Add or update query parameters
|
||||
assemblyai_api_key = litellm.utils.get_secret(secret_name="ASSEMBLYAI_API_KEY")
|
||||
assemblyai_api_key = passthrough_endpoint_router.get_credentials(
|
||||
custom_llm_provider="assemblyai",
|
||||
region_name=assembly_region,
|
||||
)
|
||||
|
||||
## check for streaming
|
||||
is_streaming_request = False
|
||||
|
@ -366,7 +396,10 @@ async def azure_proxy_route(
|
|||
"Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure."
|
||||
)
|
||||
# Add or update query parameters
|
||||
azure_api_key = get_secret_str(secret_name="AZURE_API_KEY")
|
||||
azure_api_key = passthrough_endpoint_router.get_credentials(
|
||||
custom_llm_provider="azure",
|
||||
region_name=None,
|
||||
)
|
||||
if azure_api_key is None:
|
||||
raise Exception(
|
||||
"Required 'AZURE_API_KEY' in environment to make pass-through calls to Azure."
|
||||
|
@ -400,7 +433,10 @@ async def openai_proxy_route(
|
|||
"""
|
||||
base_target_url = "https://api.openai.com"
|
||||
# Add or update query parameters
|
||||
openai_api_key = get_secret_str(secret_name="OPENAI_API_KEY")
|
||||
openai_api_key = passthrough_endpoint_router.get_credentials(
|
||||
custom_llm_provider="openai",
|
||||
region_name=None,
|
||||
)
|
||||
if openai_api_key is None:
|
||||
raise Exception(
|
||||
"Required 'OPENAI_API_KEY' in environment to make pass-through calls to OpenAI."
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional, TypedDict
|
||||
from typing import Literal, Optional, TypedDict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
|
@ -18,7 +19,7 @@ from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggin
|
|||
|
||||
class AssemblyAITranscriptResponse(TypedDict, total=False):
|
||||
id: str
|
||||
language_model: str
|
||||
speech_model: str
|
||||
acoustic_model: str
|
||||
language_code: str
|
||||
status: str
|
||||
|
@ -27,7 +28,8 @@ class AssemblyAITranscriptResponse(TypedDict, total=False):
|
|||
|
||||
class AssemblyAIPassthroughLoggingHandler:
|
||||
def __init__(self):
|
||||
self.assembly_ai_base_url = "https://api.assemblyai.com/v2"
|
||||
self.assembly_ai_base_url = "https://api.assemblyai.com"
|
||||
self.assembly_ai_eu_base_url = "https://eu.assemblyai.com"
|
||||
"""
|
||||
The base URL for the AssemblyAI API
|
||||
"""
|
||||
|
@ -43,8 +45,6 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
The maximum number of polling attempts for the AssemblyAI API.
|
||||
"""
|
||||
|
||||
self.assemblyai_api_key = os.environ.get("ASSEMBLYAI_API_KEY")
|
||||
|
||||
def assemblyai_passthrough_logging_handler(
|
||||
self,
|
||||
httpx_response: httpx.Response,
|
||||
|
@ -90,27 +90,34 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
"""
|
||||
from ..pass_through_endpoints import pass_through_endpoint_logging
|
||||
|
||||
model = response_body.get("model", "")
|
||||
verbose_proxy_logger.debug("response body", json.dumps(response_body, indent=4))
|
||||
model = response_body.get("speech_model", "")
|
||||
verbose_proxy_logger.debug(
|
||||
"response body %s", json.dumps(response_body, indent=4)
|
||||
)
|
||||
kwargs["model"] = model
|
||||
kwargs["custom_llm_provider"] = "assemblyai"
|
||||
logging_obj.model_call_details["model"] = model
|
||||
logging_obj.model_call_details["custom_llm_provider"] = "assemblyai"
|
||||
|
||||
transcript_id = response_body.get("id")
|
||||
if transcript_id is None:
|
||||
raise ValueError(
|
||||
"Transcript ID is required to log the cost of the transcription"
|
||||
)
|
||||
transcript_response = self._poll_assembly_for_transcript_response(transcript_id)
|
||||
transcript_response = self._poll_assembly_for_transcript_response(
|
||||
transcript_id=transcript_id, url_route=url_route
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"finished polling assembly for transcript response- got transcript response",
|
||||
"finished polling assembly for transcript response- got transcript response %s",
|
||||
json.dumps(transcript_response, indent=4),
|
||||
)
|
||||
if transcript_response:
|
||||
cost = self.get_cost_for_assembly_transcript(transcript_response)
|
||||
cost = self.get_cost_for_assembly_transcript(
|
||||
speech_model=model,
|
||||
transcript_response=transcript_response,
|
||||
)
|
||||
kwargs["response_cost"] = cost
|
||||
|
||||
logging_obj.model_call_details["model"] = logging_obj.model
|
||||
|
||||
# Make standard logging object for Vertex AI
|
||||
standard_logging_object = get_standard_logging_object_payload(
|
||||
kwargs=kwargs,
|
||||
|
@ -157,7 +164,11 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
return {}
|
||||
return dict(transcript_response)
|
||||
|
||||
def _get_assembly_transcript(self, transcript_id: str) -> Optional[dict]:
|
||||
def _get_assembly_transcript(
|
||||
self,
|
||||
transcript_id: str,
|
||||
request_region: Optional[Literal["eu"]] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the transcript details from AssemblyAI API
|
||||
|
||||
|
@ -167,10 +178,25 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
Returns:
|
||||
Optional[dict]: Transcript details if successful, None otherwise
|
||||
"""
|
||||
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
|
||||
passthrough_endpoint_router,
|
||||
)
|
||||
|
||||
_base_url = (
|
||||
self.assembly_ai_eu_base_url
|
||||
if request_region == "eu"
|
||||
else self.assembly_ai_base_url
|
||||
)
|
||||
_api_key = passthrough_endpoint_router.get_credentials(
|
||||
custom_llm_provider="assemblyai",
|
||||
region_name=request_region,
|
||||
)
|
||||
if _api_key is None:
|
||||
raise ValueError("AssemblyAI API key not found")
|
||||
try:
|
||||
url = f"{self.assembly_ai_base_url}/transcript/{transcript_id}"
|
||||
url = f"{_base_url}/v2/transcript/{transcript_id}"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.assemblyai_api_key}",
|
||||
"Authorization": f"Bearer {_api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
@ -179,11 +205,15 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"Error getting AssemblyAI transcript: {str(e)}")
|
||||
verbose_proxy_logger.exception(
|
||||
f"[Non blocking logging error] Error getting AssemblyAI transcript: {str(e)}"
|
||||
)
|
||||
return None
|
||||
|
||||
def _poll_assembly_for_transcript_response(
|
||||
self, transcript_id: str
|
||||
self,
|
||||
transcript_id: str,
|
||||
url_route: Optional[str] = None,
|
||||
) -> Optional[AssemblyAITranscriptResponse]:
|
||||
"""
|
||||
Poll the status of the transcript until it is completed or timeout (30 minutes)
|
||||
|
@ -191,7 +221,12 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
for _ in range(
|
||||
self.max_polling_attempts
|
||||
): # 180 attempts * 10s = 30 minutes max
|
||||
transcript = self._get_assembly_transcript(transcript_id)
|
||||
transcript = self._get_assembly_transcript(
|
||||
request_region=AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
|
||||
url=url_route
|
||||
),
|
||||
transcript_id=transcript_id,
|
||||
)
|
||||
if transcript is None:
|
||||
return None
|
||||
if (
|
||||
|
@ -205,6 +240,7 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
@staticmethod
|
||||
def get_cost_for_assembly_transcript(
|
||||
transcript_response: AssemblyAITranscriptResponse,
|
||||
speech_model: str,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
Get the cost for the assembly transcript
|
||||
|
@ -212,7 +248,50 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
_audio_duration = transcript_response.get("audio_duration")
|
||||
if _audio_duration is None:
|
||||
return None
|
||||
return _audio_duration * 0.0001
|
||||
_cost_per_second = (
|
||||
AssemblyAIPassthroughLoggingHandler.get_cost_per_second_for_assembly_model(
|
||||
speech_model=speech_model
|
||||
)
|
||||
)
|
||||
if _cost_per_second is None:
|
||||
return None
|
||||
return _audio_duration * _cost_per_second
|
||||
|
||||
@staticmethod
def get_cost_per_second_for_assembly_model(speech_model: str) -> Optional[float]:
    """
    Get the cost per second for the assembly model.
    Falls back to assemblyai/nano if the specific speech model info cannot be found.

    Args:
        speech_model: AssemblyAI speech model name taken from the transcript response.

    Returns:
        Optional[float]: ``input_cost_per_second`` from litellm's model info,
        or None when neither the requested model nor the fallback has pricing.
        Never raises - lookup errors are logged and swallowed (non-blocking).
    """
    try:
        # Try the requested model first, then fall back to assemblyai/nano.
        # (Previously this was two copy-pasted try/except lookup blocks.)
        for candidate_model in (speech_model, "assemblyai/nano"):
            try:
                model_info = litellm.get_model_info(
                    model=candidate_model,
                    custom_llm_provider="assemblyai",
                )
            except Exception:
                continue  # model not found - try the next candidate
            if model_info and model_info.get("input_cost_per_second") is not None:
                return model_info.get("input_cost_per_second")
        return None
    except Exception as e:
        # Cost tracking must never break the logging path.
        verbose_proxy_logger.exception(
            f"[Non blocking logging error] Error getting AssemblyAI model info: {str(e)}"
        )
        return None
|
||||
|
||||
@staticmethod
|
||||
def _should_log_request(request_method: str) -> bool:
|
||||
|
@ -220,3 +299,25 @@ class AssemblyAIPassthroughLoggingHandler:
|
|||
only POST transcription jobs are logged. litellm will POLL assembly to wait for the transcription to complete to log the complete response / cost
|
||||
"""
|
||||
return request_method == "POST"
|
||||
|
||||
@staticmethod
def _get_assembly_region_from_url(url: Optional[str]) -> Optional[Literal["eu"]]:
    """
    Infer the AssemblyAI region from a URL.

    Returns "eu" only when the URL's hostname is exactly
    "eu.assemblyai.com"; any other hostname (or a missing URL) yields
    None, i.e. the default US region.
    NOTE(review): this matches on hostname, so a path-only route string
    (no scheme/host) always maps to None - confirm callers pass full URLs.
    """
    if url is None:
        return None
    parsed_hostname = urlparse(url).hostname
    return "eu" if parsed_hostname == "eu.assemblyai.com" else None
|
||||
|
||||
@staticmethod
def _get_assembly_base_url_from_region(region: Optional[Literal["eu"]]) -> str:
    """
    Map an AssemblyAI region to its API base URL.

    "eu" -> "https://api.eu.assemblyai.com"; anything else, including
    None, -> the default US base URL "https://api.assemblyai.com".
    """
    eu_base_url = "https://api.eu.assemblyai.com"
    us_base_url = "https://api.assemblyai.com"
    return eu_base_url if region == "eu" else us_base_url
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
from typing import Dict, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
|
||||
class PassthroughEndpointRouter:
    """
    Use this class to Set/Get credentials for pass-through endpoints
    """

    def __init__(self):
        # In-memory credential store, keyed by the derived credential name
        # (e.g. "OPENAI_API_KEY" or "ASSEMBLYAI_EU_API_KEY").
        self.credentials: Dict[str, str] = {}

    def set_pass_through_credentials(
        self,
        custom_llm_provider: str,
        api_base: Optional[str],
        api_key: Optional[str],
    ):
        """
        Set credentials for a pass-through endpoint. Used when a user adds a pass-through LLM endpoint on the UI.

        Args:
            custom_llm_provider: The provider of the pass-through endpoint
            api_base: The base URL of the pass-through endpoint
            api_key: The API key for the pass-through endpoint

        Raises:
            ValueError: when no api_key is supplied.
        """
        detected_region = self._get_region_name_from_api_base(
            api_base=api_base, custom_llm_provider=custom_llm_provider
        )
        credential_name = self._get_credential_name_for_provider(
            custom_llm_provider=custom_llm_provider,
            region_name=detected_region,
        )
        if api_key is None:
            raise ValueError("api_key is required for setting pass-through credentials")
        self.credentials[credential_name] = api_key

    def get_credentials(
        self,
        custom_llm_provider: str,
        region_name: Optional[str],
    ) -> Optional[str]:
        """
        Return the stored credential for (provider, region), falling back to
        the provider's default environment variable when nothing is stored.
        """
        credential_name = self._get_credential_name_for_provider(
            custom_llm_provider=custom_llm_provider,
            region_name=region_name,
        )
        verbose_logger.debug(
            f"Pass-through llm endpoints router, looking for credentials for {credential_name}"
        )
        stored_credential = self.credentials.get(credential_name)
        if stored_credential is not None:
            verbose_logger.debug(f"Found credentials for {credential_name}")
            return stored_credential
        verbose_logger.debug(
            f"No credentials found for {credential_name}, looking for env variable"
        )
        fallback_env_var = self._get_default_env_variable_name_passthrough_endpoint(
            custom_llm_provider=custom_llm_provider,
        )
        return get_secret_str(fallback_env_var)

    def _get_credential_name_for_provider(
        self,
        custom_llm_provider: str,
        region_name: Optional[str],
    ) -> str:
        # Region-scoped credentials use "<PROVIDER>_<REGION>_API_KEY";
        # region-less ones use "<PROVIDER>_API_KEY".
        provider_prefix = custom_llm_provider.upper()
        if region_name is None:
            return f"{provider_prefix}_API_KEY"
        return f"{provider_prefix}_{region_name.upper()}_API_KEY"

    def _get_region_name_from_api_base(
        self,
        custom_llm_provider: str,
        api_base: Optional[str],
    ) -> Optional[str]:
        """
        Get the region name from the API base.

        Each provider might have a different way of specifying the region in the API base - this is where you can use conditional logic to handle that.
        """
        # AssemblyAI is currently the only provider with regional bases;
        # any api_base containing "eu" is treated as the EU region.
        if custom_llm_provider == "assemblyai" and api_base and "eu" in api_base:
            return "eu"
        return None

    @staticmethod
    def _get_default_env_variable_name_passthrough_endpoint(
        custom_llm_provider: str,
    ) -> str:
        # Conventional env var name: "<PROVIDER>_API_KEY".
        return f"{custom_llm_provider.upper()}_API_KEY"
|
|
@ -4180,8 +4180,14 @@ class Router:
|
|||
vertex_credentials=deployment.litellm_params.vertex_credentials,
|
||||
)
|
||||
else:
|
||||
verbose_router_logger.error(
|
||||
f"Unsupported provider - {custom_llm_provider} for pass-through endpoints"
|
||||
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
|
||||
passthrough_endpoint_router,
|
||||
)
|
||||
|
||||
passthrough_endpoint_router.set_pass_through_credentials(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
api_base=deployment.litellm_params.api_base,
|
||||
api_key=deployment.litellm_params.api_key,
|
||||
)
|
||||
pass
|
||||
pass
|
||||
|
|
|
@ -1870,6 +1870,7 @@ class LlmProviders(str, Enum):
|
|||
LANGFUSE = "langfuse"
|
||||
HUMANLOOP = "humanloop"
|
||||
TOPAZ = "topaz"
|
||||
ASSEMBLYAI = "assemblyai"
|
||||
|
||||
|
||||
# Create a set of all provider values for quick lookup
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
This test ensures that the proxy can passthrough anthropic requests
|
||||
This test ensures that the proxy can passthrough requests to assemblyai
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
|
|
@ -41,7 +41,6 @@ from litellm.proxy.pass_through_endpoints.success_handler import (
|
|||
@pytest.fixture
|
||||
def assembly_handler():
|
||||
handler = AssemblyAIPassthroughLoggingHandler()
|
||||
handler.assemblyai_api_key = "test-key"
|
||||
return handler
|
||||
|
||||
|
||||
|
@ -66,7 +65,13 @@ def test_should_log_request():
|
|||
def test_get_assembly_transcript(assembly_handler, mock_transcript_response):
|
||||
"""
|
||||
Test that the _get_assembly_transcript method calls GET /v2/transcript/{transcript_id}
|
||||
and uses the test key returned by the mocked get_credentials.
|
||||
"""
|
||||
# Patch get_credentials to return "test-key"
|
||||
with patch(
|
||||
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router.get_credentials",
|
||||
return_value="test-key",
|
||||
):
|
||||
with patch("httpx.get") as mock_get:
|
||||
mock_get.return_value.json.return_value = mock_transcript_response
|
||||
mock_get.return_value.raise_for_status.return_value = None
|
||||
|
@ -89,6 +94,10 @@ def test_poll_assembly_for_transcript_response(
|
|||
"""
|
||||
Test that the _poll_assembly_for_transcript_response method returns the correct transcript response
|
||||
"""
|
||||
with patch(
|
||||
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router.get_credentials",
|
||||
return_value="test-key",
|
||||
):
|
||||
with patch("httpx.get") as mock_get:
|
||||
mock_get.return_value.json.return_value = mock_transcript_response
|
||||
mock_get.return_value.raise_for_status.return_value = None
|
||||
|
@ -98,9 +107,11 @@ def test_poll_assembly_for_transcript_response(
|
|||
assembly_handler.max_polling_attempts = 2
|
||||
|
||||
transcript = assembly_handler._poll_assembly_for_transcript_response(
|
||||
"test-transcript-id"
|
||||
"test-transcript-id",
|
||||
)
|
||||
assert transcript == AssemblyAITranscriptResponse(
|
||||
**mock_transcript_response
|
||||
)
|
||||
assert transcript == AssemblyAITranscriptResponse(**mock_transcript_response)
|
||||
|
||||
|
||||
def test_is_assemblyai_route():
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, Mock, patch, MagicMock
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../..")) #
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
|
||||
PassthroughEndpointRouter,
|
||||
)
|
||||
|
||||
passthrough_endpoint_router = PassthroughEndpointRouter()
|
||||
|
||||
"""
|
||||
1. Basic Usage
|
||||
- Set OpenAI, AssemblyAI, Anthropic, Cohere credentials
|
||||
- GET credentials from passthrough_endpoint_router
|
||||
|
||||
2. Basic Usage - when not using DB
|
||||
- No credentials set
|
||||
- call GET credentials with provider name, assert that it reads the secret from the environment variable
|
||||
|
||||
|
||||
3. Unit test for _get_default_env_variable_name_passthrough_endpoint
|
||||
"""
|
||||
|
||||
|
||||
class TestPassthroughEndpointRouter(unittest.TestCase):
    """Unit tests for PassthroughEndpointRouter credential storage and lookup."""

    def setUp(self):
        # Fresh router per test so credentials never leak between cases.
        self.router = PassthroughEndpointRouter()

    def test_set_and_get_credentials(self):
        """
        1. Basic Usage:
        - Set credentials for OpenAI, AssemblyAI, Anthropic, Cohere
        - GET credentials from passthrough_endpoint_router (from the memory store when available)
        """
        # Providers with no region-specific logic round-trip directly.
        for provider, api_key in (
            ("openai", "openai_key"),
            ("anthropic", "anthropic_key"),
            ("cohere", "cohere_key"),
        ):
            self.router.set_pass_through_credentials(provider, None, api_key)
            self.assertEqual(self.router.get_credentials(provider, None), api_key)

        # AssemblyAI: an API base containing 'eu' triggers the regional
        # credential name, so retrieval must pass the matching region.
        api_base_eu = "https://api.eu.assemblyai.com"
        self.router.set_pass_through_credentials(
            "assemblyai", api_base_eu, "assemblyai_key"
        )
        self.assertEqual(
            self.router.get_credentials("assemblyai", "eu"), "assemblyai_key"
        )

    def test_get_credentials_from_env(self):
        """
        2. Basic Usage - when not using the database:
        - No credentials set in memory
        - Call get_credentials with provider name and expect it to read from the environment variable (via get_secret_str)
        """
        env_cases = (
            ("openai", "env_openai_key", "OPENAI_API_KEY"),
            ("cohere", "env_cohere_key", "COHERE_API_KEY"),
            ("anthropic", "env_anthropic_key", "ANTHROPIC_API_KEY"),
            ("azure", "env_azure_key", "AZURE_API_KEY"),
        )
        for provider, env_value, expected_secret_name in env_cases:
            # Patch the get_secret_str function within the router's module.
            with patch(
                "litellm.proxy.pass_through_endpoints.passthrough_endpoint_router.get_secret_str"
            ) as mock_get_secret:
                mock_get_secret.return_value = env_value
                # Nothing in the memory store -> must fall back to the env var.
                self.assertEqual(
                    self.router.get_credentials(provider, None), env_value
                )
                mock_get_secret.assert_called_once_with(expected_secret_name)

    def test_default_env_variable_method(self):
        """
        3. Unit test for _get_default_env_variable_name_passthrough_endpoint:
        - Should return the provider in uppercase followed by _API_KEY.
        """
        expected_names = {
            "openai": "OPENAI_API_KEY",
            "assemblyai": "ASSEMBLYAI_API_KEY",
            "anthropic": "ANTHROPIC_API_KEY",
            "cohere": "COHERE_API_KEY",
        }
        for provider, env_var_name in expected_names.items():
            self.assertEqual(
                PassthroughEndpointRouter._get_default_env_variable_name_passthrough_endpoint(
                    provider
                ),
                env_var_name,
            )
|
|
@ -76,27 +76,6 @@ def test_initialize_deployment_for_pass_through_missing_params():
|
|||
)
|
||||
|
||||
|
||||
def test_initialize_deployment_for_pass_through_unsupported_provider():
    """
    Test initialization with an unsupported provider
    """
    unsupported_params = LiteLLM_Params(
        model="unsupported/test-model",
        use_in_pass_through=True,
    )
    deployment = Deployment(
        model_name="unsupported-test",
        litellm_params=unsupported_params,
    )
    router = Router(model_list=[])

    # Must complete without raising; an unsupported provider is only logged.
    router._initialize_deployment_for_pass_through(
        deployment=deployment,
        custom_llm_provider="unsupported_provider",
        model="unsupported/test-model",
    )
|
||||
|
||||
|
||||
def test_initialize_deployment_when_pass_through_disabled():
|
||||
"""
|
||||
Test that initialization simply exits when use_in_pass_through is False
|
||||
|
|
202
tests/store_model_in_db_tests/test_adding_passthrough_model.py
Normal file
202
tests/store_model_in_db_tests/test_adding_passthrough_model.py
Normal file
|
@ -0,0 +1,202 @@
|
|||
"""
|
||||
Test adding a pass through assemblyai model + api key + api base to the db
|
||||
wait 20 seconds
|
||||
make request
|
||||
|
||||
Cases to cover
|
||||
1. user points api base to <proxy-base>/assemblyai
|
||||
2. user points api base to <proxy-base>/assemblyai/us
|
||||
3. user points api base to <proxy-base>/assemblyai/eu
|
||||
4. Bad API Key / credential - 401
|
||||
"""
|
||||
|
||||
import time
|
||||
import assemblyai as aai
|
||||
import pytest
|
||||
import httpx
|
||||
import os
|
||||
import json
|
||||
|
||||
TEST_MASTER_KEY = "sk-1234"
|
||||
PROXY_BASE_URL = "http://0.0.0.0:4000"
|
||||
US_BASE_URL = f"{PROXY_BASE_URL}/assemblyai"
|
||||
EU_BASE_URL = f"{PROXY_BASE_URL}/eu.assemblyai"
|
||||
ASSEMBLYAI_API_KEY_ENV_VAR = "TEST_SPECIAL_ASSEMBLYAI_API_KEY"
|
||||
|
||||
|
||||
def _delete_all_assemblyai_models_from_db():
    """
    Delete all assemblyai models from the db
    """
    print("Deleting all assemblyai models from the db.......")
    auth_headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}
    model_list_response = httpx.get(
        url=f"{PROXY_BASE_URL}/v2/model/info",
        headers=auth_headers,
    )
    response_data = model_list_response.json()
    print("model list response", json.dumps(response_data, indent=4, default=str))

    # Delete every model whose litellm_params mark it as assemblyai.
    for model in response_data["data"]:
        if model.get("litellm_params", {}).get("custom_llm_provider") != "assemblyai":
            continue
        httpx.post(
            url=f"{PROXY_BASE_URL}/model/delete",
            headers=auth_headers,
            json={"id": model["model_info"]["id"]},
        )
    print("Deleted all assemblyai models from the db")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def cleanup_assemblyai_models():
    """
    Fixture to clean up AssemblyAI models before and after each test
    """
    # Pre-test cleanup, then yield to the test, then post-test cleanup.
    _delete_all_assemblyai_models_from_db()
    yield
    _delete_all_assemblyai_models_from_db()
|
||||
|
||||
|
||||
def test_e2e_assemblyai_passthrough():
    """
    E2E: register an AssemblyAI pass-through model (US api_base) in the db,
    create a virtual key, then transcribe through the proxy's /assemblyai route.
    (Previous docstring claimed a 20-second wait that the code never performed.)
    """
    add_assembly_ai_model_to_db(api_base="https://api.assemblyai.com")
    virtual_key = create_virtual_key()
    # Make the transcription request via the proxy's US pass-through base URL.
    make_assemblyai_basic_transcribe_request(
        virtual_key=virtual_key, assemblyai_base_url=US_BASE_URL
    )
|
||||
|
||||
|
||||
def test_e2e_assemblyai_passthrough_eu():
    """
    E2E: register an AssemblyAI pass-through model (EU api_base) in the db,
    create a virtual key, then transcribe through the proxy's /eu.assemblyai route.
    (Previous docstring claimed a 20-second wait that the code never performed.)
    """
    add_assembly_ai_model_to_db(api_base="https://api.eu.assemblyai.com")
    virtual_key = create_virtual_key()
    # Make the transcription request via the proxy's EU pass-through base URL.
    make_assemblyai_basic_transcribe_request(
        virtual_key=virtual_key, assemblyai_base_url=EU_BASE_URL
    )
|
||||
|
||||
|
||||
def test_assemblyai_routes_with_bad_api_key():
    """
    Test AssemblyAI endpoints with invalid API key to ensure proper error handling
    """
    invalid_key = "sk-12222"
    transcript_payload = {
        "audio_url": "https://assembly.ai/wildfires.mp3",
        "audio_end_at": 280,
        "audio_start_from": 10,
        "auto_chapters": True,
    }
    request_headers = {
        "Authorization": f"Bearer {invalid_key}",
        "Content-Type": "application/json",
    }

    # Both the EU and the US passthrough routes must reject the bad key
    # with a 401 (checked in this order: EU first, then US).
    for route_prefix in ("/eu.assemblyai", "/assemblyai"):
        response = httpx.post(
            f"{PROXY_BASE_URL}{route_prefix}/v2/transcript",
            headers=request_headers,
            json=transcript_payload,
        )
        assert (
            response.status_code == 401
        ), f"Expected 401 unauthorized, got {response.status_code}"
|
||||
|
||||
|
||||
def create_virtual_key():
    """
    Generate a fresh virtual key via the proxy's /key/generate endpoint.

    Returns:
        The token string of the newly created key.
    """
    key_response = httpx.post(
        url=f"{PROXY_BASE_URL}/key/generate",
        headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
        json={},
    )
    response_body = key_response.json()
    print(response_body)
    return response_body["token"]
|
||||
|
||||
|
||||
def add_assembly_ai_model_to_db(
    api_base: str,
):
    """
    Add the assemblyai model to the db - makes a http request to the /model/new endpoint on PROXY_BASE_URL

    Args:
        api_base: AssemblyAI API base to register (US or EU endpoint).
    """
    assemblyai_api_key = os.getenv(ASSEMBLYAI_API_KEY_ENV_VAR)
    # Don't log the raw secret — only report whether it was found in the env.
    print("assemblyai api key is set:", assemblyai_api_key is not None)
    response = httpx.post(
        url=f"{PROXY_BASE_URL}/model/new",
        headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
        json={
            "model_name": "assemblyai/*",
            "litellm_params": {
                "model": "assemblyai/*",
                "custom_llm_provider": "assemblyai",
                "api_key": assemblyai_api_key,
                "api_base": api_base,
                # Required so the credential is exposed on pass-through routes.
                "use_in_pass_through": True,
            },
            "model_info": {},
        },
    )
    print(response.json())
|
||||
|
||||
|
||||
def make_assemblyai_basic_transcribe_request(
    virtual_key: str, assemblyai_base_url: str
):
    """
    Transcribe a known audio file through the proxy's AssemblyAI passthrough.

    Args:
        virtual_key: litellm virtual key used as the AssemblyAI api key.
        assemblyai_base_url: passthrough base url (US or EU route on the proxy).

    Fails the test if no transcript id is returned or the transcription errors.
    """
    print("making basic transcribe request to assemblyai passthrough")

    # Route the SDK through the litellm proxy using the virtual key.
    aai.settings.api_key = f"Bearer {virtual_key}"
    aai.settings.base_url = assemblyai_base_url

    # URL of the file to transcribe
    FILE_URL = "https://assembly.ai/wildfires.mp3"

    # You can also transcribe a local file by passing in a file path
    # FILE_URL = './path/to/file.mp3'

    transcriber = aai.Transcriber()
    transcript = transcriber.transcribe(FILE_URL)
    print(transcript)
    print(transcript.id)
    if not transcript.id:
        pytest.fail("Failed to get transcript id")

    try:
        # Validate the result BEFORE deleting the transcript, so a failed
        # transcription is reported even if cleanup raises.
        if transcript.status == aai.TranscriptStatus.error:
            print(transcript.error)
            pytest.fail(f"Failed to transcribe file error: {transcript.error}")
        else:
            print(transcript.text)
    finally:
        # Always clean up the transcript created by this test.
        transcript.delete_by_id(transcript.id)
|
|
@ -1,5 +1,5 @@
|
|||
import React from "react";
|
||||
import { Form } from "antd";
|
||||
import { Form, Select } from "antd";
|
||||
import { TextInput, Text } from "@tremor/react";
|
||||
import { Row, Col, Typography, Button as Button2, Upload, UploadProps } from "antd";
|
||||
import { UploadOutlined } from "@ant-design/icons";
|
||||
|
@ -72,9 +72,21 @@ const ProviderSpecificFields: React.FC<ProviderSpecificFieldsProps> = ({
|
|||
</>
|
||||
)}
|
||||
|
||||
{selectedProviderEnum === Providers.AssemblyAI && (
|
||||
<Form.Item
|
||||
rules={[{ required: true, message: "Required" }]}
|
||||
label="API Base"
|
||||
name="api_base"
|
||||
>
|
||||
<Select placeholder="Select API Base">
|
||||
<Select.Option value="https://api.assemblyai.com">https://api.assemblyai.com</Select.Option>
|
||||
<Select.Option value="https://api.eu.assemblyai.com">https://api.eu.assemblyai.com</Select.Option>
|
||||
</Select>
|
||||
</Form.Item>
|
||||
)}
|
||||
|
||||
{(selectedProviderEnum === Providers.Azure ||
|
||||
selectedProviderEnum === Providers.OpenAI_Compatible ||
|
||||
selectedProviderEnum === Providers.AssemblyAI
|
||||
selectedProviderEnum === Providers.OpenAI_Compatible
|
||||
) && (
|
||||
<Form.Item
|
||||
rules={[{ required: true, message: "Required" }]}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue