fix linting errors

This commit is contained in:
Ishaan Jaff 2024-08-23 15:44:31 -07:00
parent 49b25db516
commit c3987745fe
2 changed files with 37 additions and 2 deletions

View file

@ -121,6 +121,7 @@ from .llms.prompt_templates.factory import (
)
from .llms.sagemaker import SagemakerLLM
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.text_to_speech.vertex_ai import VertexTextToSpeechAPI
from .llms.triton import TritonChatCompletion
from .llms.vertex_ai_partner import VertexAIPartnerModels
from .llms.vertex_httpx import VertexLLM
@ -165,6 +166,7 @@ bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
vertex_partner_models_chat_completion = VertexAIPartnerModels()
vertex_text_to_speech = VertexTextToSpeechAPI()
watsonxai = IBMWatsonXAI()
sagemaker_llm = SagemakerLLM()
####### COMPLETION ENDPOINTS ################
@ -945,7 +947,6 @@ def completion(
text_completion=kwargs.get("text_completion"),
azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
user_continue_message=kwargs.get("user_continue_message"),
)
logging.update_environment_variables(
model=model,
@ -4730,6 +4731,8 @@ def speech(
if max_retries is None:
max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES
logging_obj = kwargs.get("litellm_logging_obj", None)
response: Optional[HttpxBinaryResponseContent] = None
if custom_llm_provider == "openai":
api_base = (
@ -4815,6 +4818,38 @@ def speech(
client=client, # pass AsyncOpenAI, OpenAI client
aspeech=aspeech,
)
elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai_beta":
from litellm.types.router import GenericLiteLLMParams
generic_optional_params = GenericLiteLLMParams(**kwargs)
api_base = generic_optional_params.api_base or ""
vertex_ai_project = (
generic_optional_params.vertex_project
or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT")
)
vertex_ai_location = (
generic_optional_params.vertex_location
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
)
vertex_credentials = generic_optional_params.vertex_credentials or get_secret(
"VERTEXAI_CREDENTIALS"
)
response = vertex_text_to_speech.audio_speech(
_is_async=aspeech,
vertex_credentials=vertex_credentials,
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
timeout=timeout,
api_base=api_base,
model=model,
input=input,
voice=voice,
optional_params=optional_params,
logging_obj=logging_obj,
)
if response is None:
raise Exception(