diff --git a/.circleci/config.yml b/.circleci/config.yml
index e7e36f1c93..1af15b03a5 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -1552,6 +1552,7 @@ jobs:
             pip install "pytest-retry==1.6.3"
             pip install "pytest-mock==3.12.0"
             pip install "pytest-asyncio==0.21.1"
+            pip install "assemblyai==0.37.0"
       - run:
           name: Build Docker image
           command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 713925c638..a131e6ce85 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -248,6 +248,7 @@ class LiteLLMRoutes(enum.Enum):
         "/azure",
         "/openai",
         "/assemblyai",
+        "/eu.assemblyai",
     ]
 
     anthropic_routes = [
diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py
index dce20b9775..3da970234f 100644
--- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py
+++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py
@@ -20,9 +20,13 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
 )
 from litellm.secret_managers.main import get_secret_str
 
+from .passthrough_endpoint_router import PassthroughEndpointRouter
+
 router = APIRouter()
 default_vertex_config = None
 
+passthrough_endpoint_router = PassthroughEndpointRouter()
+
 
 def create_request_copy(request: Request):
     return {
@@ -68,8 +72,9 @@ async def gemini_proxy_route(
     updated_url = base_url.copy_with(path=encoded_endpoint)
 
     # Add or update query parameters
-    gemini_api_key: Optional[str] = litellm.utils.get_secret(  # type: ignore
-        secret_name="GEMINI_API_KEY"
+    gemini_api_key: Optional[str] = passthrough_endpoint_router.get_credentials(
+        custom_llm_provider="gemini",
+        region_name=None,
     )
     if gemini_api_key is None:
         raise Exception(
@@ -126,7 +131,10 @@ async def cohere_proxy_route(
     updated_url = base_url.copy_with(path=encoded_endpoint)
 
     # Add or update query parameters
-    cohere_api_key = litellm.utils.get_secret(secret_name="COHERE_API_KEY")
+    cohere_api_key = passthrough_endpoint_router.get_credentials(
+        custom_llm_provider="cohere",
+        region_name=None,
+    )
 
     ## check for streaming
     is_streaming_request = False
@@ -175,7 +183,10 @@ async def anthropic_proxy_route(
     updated_url = base_url.copy_with(path=encoded_endpoint)
 
     # Add or update query parameters
-    anthropic_api_key = litellm.utils.get_secret(secret_name="ANTHROPIC_API_KEY")
+    anthropic_api_key = passthrough_endpoint_router.get_credentials(
+        custom_llm_provider="anthropic",
+        region_name=None,
+    )
 
     ## check for streaming
     is_streaming_request = False
@@ -297,18 +308,34 @@ def _is_bedrock_agent_runtime_route(endpoint: str) -> bool:
     methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
     tags=["AssemblyAI Pass-through", "pass-through"],
 )
+@router.api_route(
+    "/eu.assemblyai/{endpoint:path}",
+    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
+    tags=["AssemblyAI EU Pass-through", "pass-through"],
+)
 async def assemblyai_proxy_route(
     endpoint: str,
     request: Request,
     fastapi_response: Response,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
+    from litellm.proxy.pass_through_endpoints.llm_provider_handlers.assembly_passthrough_logging_handler import (
+        AssemblyAIPassthroughLoggingHandler,
+    )
+
     """
     [Docs](https://api.assemblyai.com)
     """
-    base_target_url = "https://api.assemblyai.com"
+    # Set base URL based on the route
+    assembly_region = AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
+        url=str(request.url)
+    )
+    base_target_url = (
+        AssemblyAIPassthroughLoggingHandler._get_assembly_base_url_from_region(
+            region=assembly_region
+        )
+    )
     encoded_endpoint = httpx.URL(endpoint).path
-    # Ensure endpoint starts with '/' for proper URL construction
     if not encoded_endpoint.startswith("/"):
         encoded_endpoint = "/" + encoded_endpoint
 
@@ -318,7 +345,10 @@ async def assemblyai_proxy_route(
     updated_url = base_url.copy_with(path=encoded_endpoint)
 
     # Add or update query parameters
-    assemblyai_api_key = litellm.utils.get_secret(secret_name="ASSEMBLYAI_API_KEY")
+    assemblyai_api_key = passthrough_endpoint_router.get_credentials(
+        custom_llm_provider="assemblyai",
+        region_name=assembly_region,
+    )
 
     ## check for streaming
     is_streaming_request = False
@@ -366,7 +396,10 @@ async def azure_proxy_route(
             "Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure."
         )
     # Add or update query parameters
-    azure_api_key = get_secret_str(secret_name="AZURE_API_KEY")
+    azure_api_key = passthrough_endpoint_router.get_credentials(
+        custom_llm_provider="azure",
+        region_name=None,
+    )
     if azure_api_key is None:
         raise Exception(
             "Required 'AZURE_API_KEY' in environment to make pass-through calls to Azure."
@@ -400,7 +433,10 @@ async def openai_proxy_route(
     """
     base_target_url = "https://api.openai.com"
     # Add or update query parameters
-    openai_api_key = get_secret_str(secret_name="OPENAI_API_KEY")
+    openai_api_key = passthrough_endpoint_router.get_credentials(
+        custom_llm_provider="openai",
+        region_name=None,
+    )
     if openai_api_key is None:
         raise Exception(
             "Required 'OPENAI_API_KEY' in environment to make pass-through calls to OpenAI."
diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/assembly_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/assembly_passthrough_logging_handler.py
index bf0f0f73a5..2418a435b4 100644
--- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/assembly_passthrough_logging_handler.py
+++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/assembly_passthrough_logging_handler.py
@@ -1,12 +1,13 @@
 import asyncio
 import json
-import os
 import time
 from datetime import datetime
-from typing import Optional, TypedDict
+from typing import Literal, Optional, TypedDict
+from urllib.parse import urlparse
 
 import httpx
 
+import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.litellm_core_utils.litellm_logging import (
@@ -18,7 +19,7 @@ from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggin
 
 class AssemblyAITranscriptResponse(TypedDict, total=False):
     id: str
-    language_model: str
+    speech_model: str
     acoustic_model: str
     language_code: str
     status: str
@@ -27,7 +28,8 @@ class AssemblyAITranscriptResponse(TypedDict, total=False):
 
 class AssemblyAIPassthroughLoggingHandler:
     def __init__(self):
-        self.assembly_ai_base_url = "https://api.assemblyai.com/v2"
+        self.assembly_ai_base_url = "https://api.assemblyai.com"
+        self.assembly_ai_eu_base_url = "https://eu.assemblyai.com"
         """
         The base URL for the AssemblyAI API
         """
@@ -43,8 +45,6 @@ class AssemblyAIPassthroughLoggingHandler:
         The maximum number of polling attempts for the AssemblyAI API.
         """
 
-        self.assemblyai_api_key = os.environ.get("ASSEMBLYAI_API_KEY")
-
     def assemblyai_passthrough_logging_handler(
         self,
         httpx_response: httpx.Response,
@@ -90,27 +90,34 @@ class AssemblyAIPassthroughLoggingHandler:
         """
        from ..pass_through_endpoints import pass_through_endpoint_logging
 
-        model = response_body.get("model", "")
-        verbose_proxy_logger.debug("response body", json.dumps(response_body, indent=4))
+        model = response_body.get("speech_model", "")
+        verbose_proxy_logger.debug(
+            "response body %s", json.dumps(response_body, indent=4)
+        )
         kwargs["model"] = model
         kwargs["custom_llm_provider"] = "assemblyai"
+        logging_obj.model_call_details["model"] = model
+        logging_obj.model_call_details["custom_llm_provider"] = "assemblyai"
 
         transcript_id = response_body.get("id")
         if transcript_id is None:
             raise ValueError(
                 "Transcript ID is required to log the cost of the transcription"
             )
-        transcript_response = self._poll_assembly_for_transcript_response(transcript_id)
+        transcript_response = self._poll_assembly_for_transcript_response(
+            transcript_id=transcript_id, url_route=url_route
+        )
         verbose_proxy_logger.debug(
-            "finished polling assembly for transcript response- got transcript response",
+            "finished polling assembly for transcript response - got transcript response %s",
             json.dumps(transcript_response, indent=4),
         )
 
         if transcript_response:
-            cost = self.get_cost_for_assembly_transcript(transcript_response)
+            cost = self.get_cost_for_assembly_transcript(
+                speech_model=model,
+                transcript_response=transcript_response,
+            )
             kwargs["response_cost"] = cost
 
-        logging_obj.model_call_details["model"] = logging_obj.model
-
         # Make standard logging object for Vertex AI
         standard_logging_object = get_standard_logging_object_payload(
             kwargs=kwargs,
@@ -157,7 +164,11 @@ class AssemblyAIPassthroughLoggingHandler:
             return {}
         return dict(transcript_response)
 
-    def _get_assembly_transcript(self, transcript_id: str) -> Optional[dict]:
+    def _get_assembly_transcript(
+        self,
+        transcript_id: str,
+        request_region: Optional[Literal["eu"]] = None,
+    ) -> Optional[dict]:
         """
         Get the transcript details from AssemblyAI API
@@ -167,10 +178,25 @@
         Returns:
             Optional[dict]: Transcript details if successful, None otherwise
         """
+        from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
+            passthrough_endpoint_router,
+        )
+
+        _base_url = (
+            self.assembly_ai_eu_base_url
+            if request_region == "eu"
+            else self.assembly_ai_base_url
+        )
+        _api_key = passthrough_endpoint_router.get_credentials(
+            custom_llm_provider="assemblyai",
+            region_name=request_region,
+        )
+        if _api_key is None:
+            raise ValueError("AssemblyAI API key not found")
         try:
-            url = f"{self.assembly_ai_base_url}/transcript/{transcript_id}"
+            url = f"{_base_url}/v2/transcript/{transcript_id}"
             headers = {
-                "Authorization": f"Bearer {self.assemblyai_api_key}",
+                "Authorization": f"Bearer {_api_key}",
                 "Content-Type": "application/json",
             }
 
@@ -179,11 +205,15 @@
             return response.json()
         except Exception as e:
-            verbose_proxy_logger.debug(f"Error getting AssemblyAI transcript: {str(e)}")
+            verbose_proxy_logger.exception(
+                f"[Non blocking logging error] Error getting AssemblyAI transcript: {str(e)}"
+            )
             return None
 
     def _poll_assembly_for_transcript_response(
-        self, transcript_id: str
+        self,
+        transcript_id: str,
+        url_route: Optional[str] = None,
     ) -> Optional[AssemblyAITranscriptResponse]:
         """
         Poll the status of the transcript until it is completed or timeout (30 minutes)
@@ -191,7 +221,12 @@
         for _ in range(
             self.max_polling_attempts
         ):  # 180 attempts * 10s = 30 minutes max
-            transcript = self._get_assembly_transcript(transcript_id)
+            transcript = self._get_assembly_transcript(
+                request_region=AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
+                    url=url_route
+                ),
+                transcript_id=transcript_id,
+            )
             if transcript is None:
                 return None
             if (
@@ -205,6 +240,7 @@
     @staticmethod
     def get_cost_for_assembly_transcript(
         transcript_response: AssemblyAITranscriptResponse,
+        speech_model: str,
     ) -> Optional[float]:
         """
         Get the cost for the assembly transcript
@@ -212,7 +248,50 @@
         _audio_duration = transcript_response.get("audio_duration")
         if _audio_duration is None:
             return None
-        return _audio_duration * 0.0001
+        _cost_per_second = (
+            AssemblyAIPassthroughLoggingHandler.get_cost_per_second_for_assembly_model(
+                speech_model=speech_model
+            )
+        )
+        if _cost_per_second is None:
+            return None
+        return _audio_duration * _cost_per_second
+
+    @staticmethod
+    def get_cost_per_second_for_assembly_model(speech_model: str) -> Optional[float]:
+        """
+        Get the cost per second for the assembly model.
+        Falls back to assemblyai/nano if the specific speech model info cannot be found.
+        """
+        try:
+            # First try with the provided speech model
+            try:
+                model_info = litellm.get_model_info(
+                    model=speech_model,
+                    custom_llm_provider="assemblyai",
+                )
+                if model_info and model_info.get("input_cost_per_second") is not None:
+                    return model_info.get("input_cost_per_second")
+            except Exception:
+                pass  # Continue to fallback if model not found
+
+            # Fallback to assemblyai/nano if speech model info not found
+            try:
+                model_info = litellm.get_model_info(
+                    model="assemblyai/nano",
+                    custom_llm_provider="assemblyai",
+                )
+                if model_info and model_info.get("input_cost_per_second") is not None:
+                    return model_info.get("input_cost_per_second")
+            except Exception:
+                pass
+
+            return None
+        except Exception as e:
+            verbose_proxy_logger.exception(
+                f"[Non blocking logging error] Error getting AssemblyAI model info: {str(e)}"
+            )
+            return None
 
     @staticmethod
     def _should_log_request(request_method: str) -> bool:
         """
@@ -220,3 +299,25 @@
         only POST transcription jobs are logged. litellm will POLL assembly to wait for the transcription to complete to log the complete response / cost
         """
         return request_method == "POST"
+
+    @staticmethod
+    def _get_assembly_region_from_url(url: Optional[str]) -> Optional[Literal["eu"]]:
+        """
+        Get the region from the URL
+        """
+        if url is None:
+            return None
+        if urlparse(url).hostname == "eu.assemblyai.com":
+            return "eu"
+        return None
+
+    @staticmethod
+    def _get_assembly_base_url_from_region(region: Optional[Literal["eu"]]) -> str:
+        """
+        Get the base URL for the AssemblyAI API
+        if region == "eu", return "https://api.eu.assemblyai.com"
+        else return "https://api.assemblyai.com"
+        """
+        if region == "eu":
+            return "https://api.eu.assemblyai.com"
+        return "https://api.assemblyai.com"
diff --git a/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py b/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py
new file mode 100644
index 0000000000..adf7d0f30c
--- /dev/null
+++ b/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py
@@ -0,0 +1,93 @@
+from typing import Dict, Optional
+
+from litellm._logging import verbose_logger
+from litellm.secret_managers.main import get_secret_str
+
+
+class PassthroughEndpointRouter:
+    """
+    Use this class to Set/Get credentials for pass-through endpoints
+    """
+
+    def __init__(self):
+        self.credentials: Dict[str, str] = {}
+
+    def set_pass_through_credentials(
+        self,
+        custom_llm_provider: str,
+        api_base: Optional[str],
+        api_key: Optional[str],
+    ):
+        """
+        Set credentials for a pass-through endpoint. Used when a user adds a pass-through LLM endpoint on the UI.
+
+        Args:
+            custom_llm_provider: The provider of the pass-through endpoint
+            api_base: The base URL of the pass-through endpoint
+            api_key: The API key for the pass-through endpoint
+        """
+        credential_name = self._get_credential_name_for_provider(
+            custom_llm_provider=custom_llm_provider,
+            region_name=self._get_region_name_from_api_base(
+                api_base=api_base, custom_llm_provider=custom_llm_provider
+            ),
+        )
+        if api_key is None:
+            raise ValueError("api_key is required for setting pass-through credentials")
+        self.credentials[credential_name] = api_key
+
+    def get_credentials(
+        self,
+        custom_llm_provider: str,
+        region_name: Optional[str],
+    ) -> Optional[str]:
+        credential_name = self._get_credential_name_for_provider(
+            custom_llm_provider=custom_llm_provider,
+            region_name=region_name,
+        )
+        verbose_logger.debug(
+            f"Pass-through llm endpoints router, looking for credentials for {credential_name}"
+        )
+        if credential_name in self.credentials:
+            verbose_logger.debug(f"Found credentials for {credential_name}")
+            return self.credentials[credential_name]
+        else:
+            verbose_logger.debug(
+                f"No credentials found for {credential_name}, looking for env variable"
+            )
+            _env_variable_name = (
+                self._get_default_env_variable_name_passthrough_endpoint(
+                    custom_llm_provider=custom_llm_provider,
+                )
+            )
+            return get_secret_str(_env_variable_name)
+
+    def _get_credential_name_for_provider(
+        self,
+        custom_llm_provider: str,
+        region_name: Optional[str],
+    ) -> str:
+        if region_name is None:
+            return f"{custom_llm_provider.upper()}_API_KEY"
+        return f"{custom_llm_provider.upper()}_{region_name.upper()}_API_KEY"
+
+    def _get_region_name_from_api_base(
+        self,
+        custom_llm_provider: str,
+        api_base: Optional[str],
+    ) -> Optional[str]:
+        """
+        Get the region name from the API base.
+
+        Each provider might have a different way of specifying the region in the API base - this is where you can use conditional logic to handle that.
+        """
+        if custom_llm_provider == "assemblyai":
+            if api_base and "eu" in api_base:
+                return "eu"
+        return None
+
+    @staticmethod
+    def _get_default_env_variable_name_passthrough_endpoint(
+        custom_llm_provider: str,
+    ) -> str:
+        return f"{custom_llm_provider.upper()}_API_KEY"
diff --git a/litellm/router.py b/litellm/router.py
index b61c30dd57..6d6a92d3b1 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -4180,8 +4180,14 @@ class Router:
                 vertex_credentials=deployment.litellm_params.vertex_credentials,
             )
         else:
-            verbose_router_logger.error(
-                f"Unsupported provider - {custom_llm_provider} for pass-through endpoints"
+            from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
+                passthrough_endpoint_router,
+            )
+
+            passthrough_endpoint_router.set_pass_through_credentials(
+                custom_llm_provider=custom_llm_provider,
+                api_base=deployment.litellm_params.api_base,
+                api_key=deployment.litellm_params.api_key,
             )
             pass
         pass
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index a2e5448dd1..76d7f008bb 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1870,6 +1870,7 @@ class LlmProviders(str, Enum):
     LANGFUSE = "langfuse"
     HUMANLOOP = "humanloop"
     TOPAZ = "topaz"
+    ASSEMBLYAI = "assemblyai"
 
 
 # Create a set of all provider values for quick lookup
diff --git a/tests/pass_through_tests/test_assembly_ai.py b/tests/pass_through_tests/test_assembly_ai.py
index dd9f91a0b4..2d01ef2c1b 100644
--- a/tests/pass_through_tests/test_assembly_ai.py
+++ b/tests/pass_through_tests/test_assembly_ai.py
@@ -1,5 +1,5 @@
 """
-This test ensures that the proxy can passthrough anthropic requests
+This test ensures that the proxy can passthrough requests to assemblyai
 """
 
 import pytest
diff --git a/tests/pass_through_unit_tests/test_assemblyai_unit_tests_passthrough.py b/tests/pass_through_unit_tests/test_assemblyai_unit_tests_passthrough.py
index 8279d4c988..963f1ad6ef 100644
--- a/tests/pass_through_unit_tests/test_assemblyai_unit_tests_passthrough.py
+++ b/tests/pass_through_unit_tests/test_assemblyai_unit_tests_passthrough.py
@@ -41,7 +41,6 @@ from litellm.proxy.pass_through_endpoints.success_handler import (
 @pytest.fixture
 def assembly_handler():
     handler = AssemblyAIPassthroughLoggingHandler()
-    handler.assemblyai_api_key = "test-key"
     return handler
 
 
@@ -66,21 +65,27 @@ def test_should_log_request():
 def test_get_assembly_transcript(assembly_handler, mock_transcript_response):
     """
     Test that the _get_assembly_transcript method calls GET /v2/transcript/{transcript_id}
+    and uses the test key returned by the mocked get_credentials.
     """
-    with patch("httpx.get") as mock_get:
-        mock_get.return_value.json.return_value = mock_transcript_response
-        mock_get.return_value.raise_for_status.return_value = None
+    # Patch get_credentials to return "test-key"
+    with patch(
+        "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router.get_credentials",
+        return_value="test-key",
+    ):
+        with patch("httpx.get") as mock_get:
+            mock_get.return_value.json.return_value = mock_transcript_response
+            mock_get.return_value.raise_for_status.return_value = None
 
-        transcript = assembly_handler._get_assembly_transcript("test-transcript-id")
-        assert transcript == mock_transcript_response
+            transcript = assembly_handler._get_assembly_transcript("test-transcript-id")
+            assert transcript == mock_transcript_response
 
-        mock_get.assert_called_once_with(
-            "https://api.assemblyai.com/v2/transcript/test-transcript-id",
-            headers={
-                "Authorization": "Bearer test-key",
-                "Content-Type": "application/json",
-            },
-        )
+            mock_get.assert_called_once_with(
+                "https://api.assemblyai.com/v2/transcript/test-transcript-id",
+                headers={
+                    "Authorization": "Bearer test-key",
+                    "Content-Type": "application/json",
+                },
+            )
 
 
 def test_poll_assembly_for_transcript_response(
@@ -89,18 +94,24 @@
     """
     Test that the _poll_assembly_for_transcript_response method returns the correct transcript response
     """
-    with patch("httpx.get") as mock_get:
-        mock_get.return_value.json.return_value = mock_transcript_response
-        mock_get.return_value.raise_for_status.return_value = None
+    with patch(
+        "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router.get_credentials",
+        return_value="test-key",
+    ):
+        with patch("httpx.get") as mock_get:
+            mock_get.return_value.json.return_value = mock_transcript_response
+            mock_get.return_value.raise_for_status.return_value = None
 
-        # Override polling settings for faster test
-        assembly_handler.polling_interval = 0.01
-        assembly_handler.max_polling_attempts = 2
+            # Override polling settings for faster test
+            assembly_handler.polling_interval = 0.01
+            assembly_handler.max_polling_attempts = 2
 
-        transcript = assembly_handler._poll_assembly_for_transcript_response(
-            "test-transcript-id"
-        )
-        assert transcript == AssemblyAITranscriptResponse(**mock_transcript_response)
+            transcript = assembly_handler._poll_assembly_for_transcript_response(
+                "test-transcript-id",
+            )
+            assert transcript == AssemblyAITranscriptResponse(
+                **mock_transcript_response
+            )
 
 
 def test_is_assemblyai_route():
diff --git a/tests/pass_through_unit_tests/test_unit_test_passthrough_router.py b/tests/pass_through_unit_tests/test_unit_test_passthrough_router.py
new file mode 100644
index 0000000000..6e8296876a
--- /dev/null
+++ b/tests/pass_through_unit_tests/test_unit_test_passthrough_router.py
@@ -0,0 +1,134 @@
+import json
+import os
+import sys
+from datetime import datetime
+from unittest.mock import AsyncMock, Mock, patch, MagicMock
+
+sys.path.insert(0, os.path.abspath("../..")) #
+
+import unittest
+from unittest.mock import patch
+from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
+    PassthroughEndpointRouter,
+)
+
+passthrough_endpoint_router = PassthroughEndpointRouter()
+
+"""
+1. Basic Usage
+   - Set OpenAI, AssemblyAI, Anthropic, Cohere credentials
+   - GET credentials from passthrough_endpoint_router
+
+2. Basic Usage - when not using DB
+   - No credentials set
+   - call GET credentials with provider name, assert that it reads the secret from the environment variable
+
+3. Unit test for _get_default_env_variable_name_passthrough_endpoint
+"""
+
+
+class TestPassthroughEndpointRouter(unittest.TestCase):
+    def setUp(self):
+        self.router = PassthroughEndpointRouter()
+
+    def test_set_and_get_credentials(self):
+        """
+        1. Basic Usage:
+           - Set credentials for OpenAI, AssemblyAI, Anthropic, Cohere
+           - GET credentials from passthrough_endpoint_router (from the memory store when available)
+        """
+
+        # OpenAI: standard (no region-specific logic)
+        self.router.set_pass_through_credentials("openai", None, "openai_key")
+        self.assertEqual(self.router.get_credentials("openai", None), "openai_key")
+
+        # AssemblyAI: using an API base that contains 'eu' should trigger regional logic.
+        api_base_eu = "https://api.eu.assemblyai.com"
+        self.router.set_pass_through_credentials(
+            "assemblyai", api_base_eu, "assemblyai_key"
+        )
+        # When calling get_credentials, pass the region "eu" (extracted from the API base)
+        self.assertEqual(
+            self.router.get_credentials("assemblyai", "eu"), "assemblyai_key"
+        )
+
+        # Anthropic: no region set
+        self.router.set_pass_through_credentials("anthropic", None, "anthropic_key")
+        self.assertEqual(
+            self.router.get_credentials("anthropic", None), "anthropic_key"
+        )
+
+        # Cohere: no region set
+        self.router.set_pass_through_credentials("cohere", None, "cohere_key")
+        self.assertEqual(self.router.get_credentials("cohere", None), "cohere_key")
+
+    def test_get_credentials_from_env(self):
+        """
+        2. Basic Usage - when not using the database:
+           - No credentials set in memory
+           - Call get_credentials with provider name and expect it to read from the environment variable (via get_secret_str)
+        """
+        # Patch the get_secret_str function within the router's module.
+        with patch(
+            "litellm.proxy.pass_through_endpoints.passthrough_endpoint_router.get_secret_str"
+        ) as mock_get_secret:
+            mock_get_secret.return_value = "env_openai_key"
+            # For "openai", if credentials are not set, it should fallback to the env variable.
+            result = self.router.get_credentials("openai", None)
+            self.assertEqual(result, "env_openai_key")
+            mock_get_secret.assert_called_once_with("OPENAI_API_KEY")
+
+        with patch(
+            "litellm.proxy.pass_through_endpoints.passthrough_endpoint_router.get_secret_str"
+        ) as mock_get_secret:
+            mock_get_secret.return_value = "env_cohere_key"
+            result = self.router.get_credentials("cohere", None)
+            self.assertEqual(result, "env_cohere_key")
+            mock_get_secret.assert_called_once_with("COHERE_API_KEY")
+
+        with patch(
+            "litellm.proxy.pass_through_endpoints.passthrough_endpoint_router.get_secret_str"
+        ) as mock_get_secret:
+            mock_get_secret.return_value = "env_anthropic_key"
+            result = self.router.get_credentials("anthropic", None)
+            self.assertEqual(result, "env_anthropic_key")
+            mock_get_secret.assert_called_once_with("ANTHROPIC_API_KEY")
+
+        with patch(
+            "litellm.proxy.pass_through_endpoints.passthrough_endpoint_router.get_secret_str"
+        ) as mock_get_secret:
+            mock_get_secret.return_value = "env_azure_key"
+            result = self.router.get_credentials("azure", None)
+            self.assertEqual(result, "env_azure_key")
+            mock_get_secret.assert_called_once_with("AZURE_API_KEY")
+
+    def test_default_env_variable_method(self):
+        """
+        3. Unit test for _get_default_env_variable_name_passthrough_endpoint:
+           - Should return the provider in uppercase followed by _API_KEY.
+        """
+        self.assertEqual(
+            PassthroughEndpointRouter._get_default_env_variable_name_passthrough_endpoint(
+                "openai"
+            ),
+            "OPENAI_API_KEY",
+        )
+        self.assertEqual(
+            PassthroughEndpointRouter._get_default_env_variable_name_passthrough_endpoint(
+                "assemblyai"
+            ),
+            "ASSEMBLYAI_API_KEY",
+        )
+        self.assertEqual(
+            PassthroughEndpointRouter._get_default_env_variable_name_passthrough_endpoint(
+                "anthropic"
+            ),
+            "ANTHROPIC_API_KEY",
+        )
+        self.assertEqual(
+            PassthroughEndpointRouter._get_default_env_variable_name_passthrough_endpoint(
+                "cohere"
+            ),
+            "COHERE_API_KEY",
+        )
diff --git a/tests/router_unit_tests/test_router_adding_deployments.py b/tests/router_unit_tests/test_router_adding_deployments.py
index b5e2d4a526..fca3f147e5 100644
--- a/tests/router_unit_tests/test_router_adding_deployments.py
+++ b/tests/router_unit_tests/test_router_adding_deployments.py
@@ -76,27 +76,6 @@ def test_initialize_deployment_for_pass_through_missing_params():
     )
 
 
-def test_initialize_deployment_for_pass_through_unsupported_provider():
-    """
-    Test initialization with an unsupported provider
-    """
-    router = Router(model_list=[])
-    deployment = Deployment(
-        model_name="unsupported-test",
-        litellm_params=LiteLLM_Params(
-            model="unsupported/test-model",
-            use_in_pass_through=True,
-        ),
-    )
-
-    # Should not raise an error, but log a warning
-    router._initialize_deployment_for_pass_through(
-        deployment=deployment,
-        custom_llm_provider="unsupported_provider",
-        model="unsupported/test-model",
-    )
-
-
 def test_initialize_deployment_when_pass_through_disabled():
     """
     Test that initialization simply exits when use_in_pass_through is False
diff --git a/tests/store_model_in_db_tests/test_adding_passthrough_model.py b/tests/store_model_in_db_tests/test_adding_passthrough_model.py
new file mode 100644
index 0000000000..ad26e19bd6
--- /dev/null
+++ b/tests/store_model_in_db_tests/test_adding_passthrough_model.py
@@ -0,0 +1,202 @@
+"""
+Test adding a pass through assemblyai model + api key + api base to the db
+wait 20 seconds
+make request
+
+Cases to cover
+1. user points api base to /assemblyai
+2. user points api base to /assemblyai/us
+3. user points api base to /assemblyai/eu
+4. Bad API Key / credential - 401
+"""
+
+import time
+import assemblyai as aai
+import pytest
+import httpx
+import os
+import json
+
+TEST_MASTER_KEY = "sk-1234"
+PROXY_BASE_URL = "http://0.0.0.0:4000"
+US_BASE_URL = f"{PROXY_BASE_URL}/assemblyai"
+EU_BASE_URL = f"{PROXY_BASE_URL}/eu.assemblyai"
+ASSEMBLYAI_API_KEY_ENV_VAR = "TEST_SPECIAL_ASSEMBLYAI_API_KEY"
+
+
+def _delete_all_assemblyai_models_from_db():
+    """
+    Delete all assemblyai models from the db
+    """
+    print("Deleting all assemblyai models from the db.......")
+    model_list_response = httpx.get(
+        url=f"{PROXY_BASE_URL}/v2/model/info",
+        headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
+    )
+    response_data = model_list_response.json()
+    print("model list response", json.dumps(response_data, indent=4, default=str))
+    # Filter for only AssemblyAI models
+    assemblyai_models = [
+        model
+        for model in response_data["data"]
+        if model.get("litellm_params", {}).get("custom_llm_provider") == "assemblyai"
+    ]
+
+    for model in assemblyai_models:
+        model_id = model["model_info"]["id"]
+        httpx.post(
+            url=f"{PROXY_BASE_URL}/model/delete",
+            headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
+            json={"id": model_id},
+        )
+    print("Deleted all assemblyai models from the db")
+
+
+@pytest.fixture(autouse=True)
+def cleanup_assemblyai_models():
+    """
+    Fixture to clean up AssemblyAI models before and after each test
+    """
+    # Clean up before test
+    _delete_all_assemblyai_models_from_db()
+
+    # Run the test
+    yield
+
+    # Clean up after test
+    _delete_all_assemblyai_models_from_db()
+
+
+def test_e2e_assemblyai_passthrough():
+    """
+    Test adding a pass through assemblyai model + api key + api base to the db
+    wait 20 seconds
+    make request
+    """
+    add_assembly_ai_model_to_db(api_base="https://api.assemblyai.com")
+    virtual_key = create_virtual_key()
+    # make request
+    make_assemblyai_basic_transcribe_request(
+        virtual_key=virtual_key, assemblyai_base_url=US_BASE_URL
+    )
+
+    pass
+
+
+def test_e2e_assemblyai_passthrough_eu():
+    """
+    Test adding a pass through assemblyai model + api key + api base to the db
+    wait 20 seconds
+    make request
+    """
+    add_assembly_ai_model_to_db(api_base="https://api.eu.assemblyai.com")
+    virtual_key = create_virtual_key()
+    # make request
+    make_assemblyai_basic_transcribe_request(
+        virtual_key=virtual_key, assemblyai_base_url=EU_BASE_URL
+    )
+
+    pass
+
+
+def test_assemblyai_routes_with_bad_api_key():
+    """
+    Test AssemblyAI endpoints with invalid API key to ensure proper error handling
+    """
+    bad_api_key = "sk-12222"
+    payload = {
+        "audio_url": "https://assembly.ai/wildfires.mp3",
+        "audio_end_at": 280,
+        "audio_start_from": 10,
+        "auto_chapters": True,
+    }
+    headers = {
+        "Authorization": f"Bearer {bad_api_key}",
+        "Content-Type": "application/json",
+    }
+
+    # Test EU endpoint
+    eu_response = httpx.post(
+        f"{PROXY_BASE_URL}/eu.assemblyai/v2/transcript", headers=headers, json=payload
+    )
+    assert (
+        eu_response.status_code == 401
+    ), f"Expected 401 unauthorized, got {eu_response.status_code}"
+
+    # Test US endpoint
+    us_response = httpx.post(
+        f"{PROXY_BASE_URL}/assemblyai/v2/transcript", headers=headers, json=payload
+    )
+    assert (
+        us_response.status_code == 401
+    ), f"Expected 401 unauthorized, got {us_response.status_code}"
+
+
+def create_virtual_key():
+    """
+    Create a virtual key
+    """
+    response = httpx.post(
+        url=f"{PROXY_BASE_URL}/key/generate",
+        headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
+        json={},
+    )
+    print(response.json())
+    return response.json()["token"]
+
+
+def add_assembly_ai_model_to_db(
+    api_base: str,
+):
+    """
+    Add the assemblyai model to the db - makes a http request to the /model/new endpoint on PROXY_BASE_URL
+    """
+    print("assemblyai api key", os.getenv(ASSEMBLYAI_API_KEY_ENV_VAR))
+    response = httpx.post(
+        url=f"{PROXY_BASE_URL}/model/new",
+        headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
+        json={
+            "model_name": "assemblyai/*",
+            "litellm_params": {
+                "model": "assemblyai/*",
+                "custom_llm_provider": "assemblyai",
+                "api_key": os.getenv(ASSEMBLYAI_API_KEY_ENV_VAR),
+                "api_base": api_base,
+                "use_in_pass_through": True,
+            },
+            "model_info": {},
+        },
+    )
+    print(response.json())
+    pass
+
+
+def make_assemblyai_basic_transcribe_request(
+    virtual_key: str, assemblyai_base_url: str
+):
+    print("making basic transcribe request to assemblyai passthrough")
+
+    # Replace with your API key
+    aai.settings.api_key = f"Bearer {virtual_key}"
+    aai.settings.base_url = assemblyai_base_url
+
+    # URL of the file to transcribe
+    FILE_URL = "https://assembly.ai/wildfires.mp3"
+
+    # You can also transcribe a local file by passing in a file path
+    # FILE_URL = './path/to/file.mp3'
+
+    transcriber = aai.Transcriber()
+    transcript = transcriber.transcribe(FILE_URL)
+    print(transcript)
+    print(transcript.id)
+    if transcript.id:
+        transcript.delete_by_id(transcript.id)
+    else:
+        pytest.fail("Failed to get transcript id")
+
+    if transcript.status == aai.TranscriptStatus.error:
+        print(transcript.error)
+        pytest.fail(f"Failed to transcribe file error: {transcript.error}")
+    else:
+        print(transcript.text)
diff --git a/ui/litellm-dashboard/src/components/add_model/provider_specific_fields.tsx b/ui/litellm-dashboard/src/components/add_model/provider_specific_fields.tsx
index 38f7edea90..f48763a6f4 100644
--- a/ui/litellm-dashboard/src/components/add_model/provider_specific_fields.tsx
+++ b/ui/litellm-dashboard/src/components/add_model/provider_specific_fields.tsx
@@ -1,5 +1,5 @@
 import React from "react";
-import { Form } from "antd";
+import { Form, Select } from "antd";
 import { TextInput, Text } from "@tremor/react";
 import { Row, Col, Typography, Button as Button2, Upload, UploadProps } from "antd";
 import { UploadOutlined } from "@ant-design/icons";
@@ -72,9 +72,21 @@ const ProviderSpecificFields: React.FC = ({
         )}
 
+      {selectedProviderEnum === Providers.AssemblyAI && (
+
+
+      )}
+
       {(selectedProviderEnum === Providers.Azure ||
-        selectedProviderEnum === Providers.OpenAI_Compatible ||
-        selectedProviderEnum === Providers.AssemblyAI
+        selectedProviderEnum === Providers.OpenAI_Compatible
       ) && (
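
Reviewer note (illustrative, not part of the patch): the credential lookup order introduced by `PassthroughEndpointRouter` — in-memory store first, then the provider's default env variable — can be sanity-checked directly. A minimal sketch, assuming only that `litellm` is importable outside a running proxy; the key value is hypothetical:

```python
from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
    PassthroughEndpointRouter,
)

router = PassthroughEndpointRouter()

# Credentials added via the UI/DB land in the in-memory store. For AssemblyAI,
# an api_base containing "eu" maps to region "eu", so this key is stored under
# the credential name "ASSEMBLYAI_EU_API_KEY".
router.set_pass_through_credentials(
    custom_llm_provider="assemblyai",
    api_base="https://api.eu.assemblyai.com",
    api_key="sk-assemblyai-eu-example",  # hypothetical value
)
assert router.get_credentials("assemblyai", region_name="eu") == "sk-assemblyai-eu-example"

# With nothing stored for a provider, get_credentials falls back to the
# default env variable (here OPENAI_API_KEY) via get_secret_str.
openai_key = router.get_credentials("openai", region_name=None)
```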
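The EU-routing helpers added to `AssemblyAIPassthroughLoggingHandler` are static and side-effect free, so the URL-to-region-to-base-URL mapping is easy to verify in isolation. A sketch mirroring `_get_assembly_region_from_url` / `_get_assembly_base_url_from_region` as added above:

```python
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.assembly_passthrough_logging_handler import (
    AssemblyAIPassthroughLoggingHandler as Handler,
)

# Only a URL whose hostname is exactly "eu.assemblyai.com" resolves to "eu".
assert Handler._get_assembly_region_from_url(url="https://eu.assemblyai.com/v2/transcript") == "eu"
assert Handler._get_assembly_region_from_url(url="https://api.assemblyai.com/v2/transcript") is None

# Region -> base target URL used when proxying the request.
assert Handler._get_assembly_base_url_from_region(region="eu") == "https://api.eu.assemblyai.com"
assert Handler._get_assembly_base_url_from_region(region=None) == "https://api.assemblyai.com"
```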
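Cost logging likewise moves from the hard-coded `audio_duration * 0.0001` to model-based pricing with a fallback to `assemblyai/nano`. A rough sketch of the resulting arithmetic — assuming the local model-cost map has an `input_cost_per_second` entry for `assemblyai/nano`, and using a hypothetical 300-second transcript:

```python
import litellm

# Look up per-second pricing; get_model_info raises if the model is unknown,
# which the handler treats as "no cost info" (the fallback path above).
try:
    model_info = litellm.get_model_info(
        model="assemblyai/nano", custom_llm_provider="assemblyai"
    )
    cost_per_second = model_info.get("input_cost_per_second")
except Exception:
    cost_per_second = None

audio_duration = 300  # seconds; hypothetical value from the transcript response
if cost_per_second is not None:
    print(f"response_cost: {audio_duration * cost_per_second}")
```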