raise better exception if llm provider isn't passed in or inferred

Krrish Dholakia 2023-09-12 11:28:50 -07:00
parent 4acca3d4d9
commit baa69734b0
8 changed files with 63 additions and 1 deletion

@@ -931,6 +931,55 @@ def get_optional_params(  # use the openai defaults
        return optional_params
    return optional_params
def get_llm_provider(model: str, custom_llm_provider: str = None):
    try:
        # check if llm provider provided
        if custom_llm_provider:
            return model, custom_llm_provider

        # check if llm provider part of model name
        if model.split("/", 1)[0] in litellm.provider_list:
            custom_llm_provider = model.split("/", 1)[0]
            model = model.split("/", 1)[1]
            return model, custom_llm_provider

        # check if model in known model provider list
        ## openai - chatcompletion + text completion
        if model in litellm.open_ai_chat_completion_models or model in litellm.open_ai_text_completion_models:
            custom_llm_provider = "openai"
        ## cohere
        elif model in litellm.cohere_models:
            custom_llm_provider = "cohere"
        ## replicate
        elif model in litellm.replicate_models:
            custom_llm_provider = "replicate"
        ## openrouter
        elif model in litellm.openrouter_models:
            custom_llm_provider = "openrouter"
        ## vertex - text + chat models
        elif model in litellm.vertex_chat_models or model in litellm.vertex_text_models:
            custom_llm_provider = "vertex_ai"
        ## huggingface
        elif model in litellm.huggingface_models:
            custom_llm_provider = "huggingface"
        ## ai21
        elif model in litellm.ai21_models:
            custom_llm_provider = "ai21"
        ## together_ai
        elif model in litellm.together_ai_models:
            custom_llm_provider = "together_ai"
        ## aleph_alpha
        elif model in litellm.aleph_alpha_models:
            custom_llm_provider = "aleph_alpha"
        ## baseten
        elif model in litellm.baseten_models:
            custom_llm_provider = "baseten"

        if custom_llm_provider is None:
            raise ValueError(f"LLM Provider NOT provided. Pass in the LLM provider you are trying to call. E.g. For 'Huggingface' inference endpoints pass in `completion(model='huggingface/{model}',..)` Learn more: https://docs.litellm.ai/docs/providers")
        return model, custom_llm_provider
    except Exception as e:
        raise e
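
For context, a minimal sketch of how this resolution logic behaves, assuming the function is imported from litellm.utils as the hunk suggests; the model names below are illustrative, not taken from this diff:

# Hypothetical usage sketch for the get_llm_provider() added above.
from litellm.utils import get_llm_provider

# 1. An explicit provider is returned as-is.
print(get_llm_provider("gpt-3.5-turbo", custom_llm_provider="openai"))
# -> ("gpt-3.5-turbo", "openai")

# 2. A provider prefix on the model name is split off.
print(get_llm_provider("huggingface/bigcode/starcoder"))
# -> ("bigcode/starcoder", "huggingface")

# 3. A bare model name is matched against the known provider lists
#    (assuming "command-nightly" is in litellm.cohere_models).
print(get_llm_provider("command-nightly"))
# -> ("command-nightly", "cohere")

# 4. Unknown models now raise the new descriptive ValueError
#    instead of failing later with an unclear error.
try:
    get_llm_provider("some-unknown-model")
except ValueError as err:
    print(err)
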
def get_max_tokens(model: str):
    try: