This commit is contained in:
Krrish Dholakia 2025-03-18 10:48:28 -07:00
parent e1f0cffae7
commit f0da181215

View file

@ -129,17 +129,15 @@ def get_llm_provider( # noqa: PLR0915
model, custom_llm_provider
)
if custom_llm_provider:
if (
model.split("/")[0] == custom_llm_provider
): # handle scenario where model="azure/*" and custom_llm_provider="azure"
model = model.replace("{}/".format(custom_llm_provider), "")
return model, custom_llm_provider, dynamic_api_key, api_base
if custom_llm_provider and (
model.split("/")[0] != custom_llm_provider
): # handle scenario where model="azure/*" and custom_llm_provider="azure"
model = custom_llm_provider + "/" + model
if api_key and api_key.startswith("os.environ/"):
dynamic_api_key = get_secret_str(api_key)
# check if llm provider part of model name
if (
model.split("/", 1)[0] in litellm.provider_list
and model.split("/", 1)[0] not in litellm.model_list_set
@ -573,11 +571,11 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
elif custom_llm_provider == "snowflake":
api_base = (
api_base
or get_secret("SNOWFLAKE_API_BASE")
or f"https://{get_secret('SNOWFLAKE_ACCOUNT_ID')}.snowflakecomputing.com/api/v2/cortex/inference:complete"
) # type: ignore
dynamic_api_key = api_key or get_secret("SNOWFLAKE_JWT")
api_base
or get_secret_str("SNOWFLAKE_API_BASE")
or f"https://{get_secret('SNOWFLAKE_ACCOUNT_ID')}.snowflakecomputing.com/api/v2/cortex/inference:complete"
) # type: ignore
dynamic_api_key = api_key or get_secret_str("SNOWFLAKE_JWT")
if api_base is not None and not isinstance(api_base, str):
raise Exception("api base needs to be a string. api_base={}".format(api_base))