This commit is contained in:
Krrish Dholakia 2025-03-18 10:48:28 -07:00
parent e1f0cffae7
commit f0da181215

View file

@ -129,17 +129,15 @@ def get_llm_provider( # noqa: PLR0915
model, custom_llm_provider model, custom_llm_provider
) )
if custom_llm_provider: if custom_llm_provider and (
if ( model.split("/")[0] != custom_llm_provider
model.split("/")[0] == custom_llm_provider ): # handle scenario where model="azure/*" and custom_llm_provider="azure"
): # handle scenario where model="azure/*" and custom_llm_provider="azure" model = custom_llm_provider + "/" + model
model = model.replace("{}/".format(custom_llm_provider), "")
return model, custom_llm_provider, dynamic_api_key, api_base
if api_key and api_key.startswith("os.environ/"): if api_key and api_key.startswith("os.environ/"):
dynamic_api_key = get_secret_str(api_key) dynamic_api_key = get_secret_str(api_key)
# check if llm provider part of model name # check if llm provider part of model name
if ( if (
model.split("/", 1)[0] in litellm.provider_list model.split("/", 1)[0] in litellm.provider_list
and model.split("/", 1)[0] not in litellm.model_list_set and model.split("/", 1)[0] not in litellm.model_list_set
@ -573,11 +571,11 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY") dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
elif custom_llm_provider == "snowflake": elif custom_llm_provider == "snowflake":
api_base = ( api_base = (
api_base api_base
or get_secret("SNOWFLAKE_API_BASE") or get_secret_str("SNOWFLAKE_API_BASE")
or f"https://{get_secret('SNOWFLAKE_ACCOUNT_ID')}.snowflakecomputing.com/api/v2/cortex/inference:complete" or f"https://{get_secret('SNOWFLAKE_ACCOUNT_ID')}.snowflakecomputing.com/api/v2/cortex/inference:complete"
) # type: ignore ) # type: ignore
dynamic_api_key = api_key or get_secret("SNOWFLAKE_JWT") dynamic_api_key = api_key or get_secret_str("SNOWFLAKE_JWT")
if api_base is not None and not isinstance(api_base, str): if api_base is not None and not isinstance(api_base, str):
raise Exception("api base needs to be a string. api_base={}".format(api_base)) raise Exception("api base needs to be a string. api_base={}".format(api_base))