diff --git a/litellm/llms/text_to_speech/vertex_ai.py b/litellm/llms/text_to_speech/vertex_ai.py
index 335b936a76..01ab6edc05 100644
--- a/litellm/llms/text_to_speech/vertex_ai.py
+++ b/litellm/llms/text_to_speech/vertex_ai.py
@@ -47,7 +47,6 @@ class VertexTextToSpeechAPI(VertexLLM):
     def audio_speech(
         self,
         logging_obj,
-        _is_async: bool,
         vertex_project: Optional[str],
         vertex_location: Optional[str],
         vertex_credentials: Optional[str],
@@ -56,6 +55,7 @@ class VertexTextToSpeechAPI(VertexLLM):
         model: str,
         input: str,
         voice: str,
+        _is_async: Optional[bool] = False,
         optional_params: Optional[dict] = None,
         **kwargs,
     ):
diff --git a/litellm/main.py b/litellm/main.py
index 8104bfd864..4323c24a6b 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -121,6 +121,7 @@ from .llms.prompt_templates.factory import (
 )
 from .llms.sagemaker import SagemakerLLM
 from .llms.text_completion_codestral import CodestralTextCompletion
+from .llms.text_to_speech.vertex_ai import VertexTextToSpeechAPI
 from .llms.triton import TritonChatCompletion
 from .llms.vertex_ai_partner import VertexAIPartnerModels
 from .llms.vertex_httpx import VertexLLM
@@ -165,6 +166,7 @@ bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
 vertex_partner_models_chat_completion = VertexAIPartnerModels()
+vertex_text_to_speech = VertexTextToSpeechAPI()
 watsonxai = IBMWatsonXAI()
 sagemaker_llm = SagemakerLLM()
 ####### COMPLETION ENDPOINTS ################
@@ -945,7 +947,6 @@ def completion(
             text_completion=kwargs.get("text_completion"),
             azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
             user_continue_message=kwargs.get("user_continue_message"),
-
         )
         logging.update_environment_variables(
             model=model,
@@ -4730,6 +4731,8 @@ def speech(
     if max_retries is None:
         max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES
+
+    logging_obj = kwargs.get("litellm_logging_obj", None)
     response: Optional[HttpxBinaryResponseContent] = None
     if custom_llm_provider == "openai":
         api_base = (
@@ -4815,6 +4818,38 @@ def speech(
             client=client,  # pass AsyncOpenAI, OpenAI client
             aspeech=aspeech,
         )
+    elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai_beta":
+        from litellm.types.router import GenericLiteLLMParams
+
+        generic_optional_params = GenericLiteLLMParams(**kwargs)
+
+        api_base = generic_optional_params.api_base or ""
+        vertex_ai_project = (
+            generic_optional_params.vertex_project
+            or litellm.vertex_project
+            or get_secret("VERTEXAI_PROJECT")
+        )
+        vertex_ai_location = (
+            generic_optional_params.vertex_location
+            or litellm.vertex_location
+            or get_secret("VERTEXAI_LOCATION")
+        )
+        vertex_credentials = generic_optional_params.vertex_credentials or get_secret(
+            "VERTEXAI_CREDENTIALS"
+        )
+        response = vertex_text_to_speech.audio_speech(
+            _is_async=aspeech,
+            vertex_credentials=vertex_credentials,
+            vertex_project=vertex_ai_project,
+            vertex_location=vertex_ai_location,
+            timeout=timeout,
+            api_base=api_base,
+            model=model,
+            input=input,
+            voice=voice,
+            optional_params=optional_params,
+            logging_obj=logging_obj,
+        )
     if response is None:
         raise Exception(