This commit is contained in:
Krrish Dholakia 2023-09-29 11:33:58 -07:00
parent bb50729a18
commit f04d50d119
15 changed files with 84 additions and 70 deletions

View file

@ -1187,7 +1187,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
model = model.split("/", 1)[1]
return model, custom_llm_provider
# check if model in known model provider list
# check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.)
## openai - chatcompletion + text completion
if model in litellm.open_ai_chat_completion_models:
custom_llm_provider = "openai"
@ -1208,15 +1208,9 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
## vertex - text + chat models
elif model in litellm.vertex_chat_models or model in litellm.vertex_text_models:
custom_llm_provider = "vertex_ai"
## huggingface
elif model in litellm.huggingface_models:
custom_llm_provider = "huggingface"
## ai21
elif model in litellm.ai21_models:
custom_llm_provider = "ai21"
## together_ai
elif model in litellm.together_ai_models:
custom_llm_provider = "together_ai"
## aleph_alpha
elif model in litellm.aleph_alpha_models:
custom_llm_provider = "aleph_alpha"
@ -1231,6 +1225,9 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
custom_llm_provider = "petals"
if custom_llm_provider is None or custom_llm_provider=="":
print()
print("\033[1;31mProvider List: https://docs.litellm.ai/docs/providers\033[0m")
print()
raise ValueError(f"LLM Provider NOT provided. Pass in the LLM provider you are trying to call. E.g. For 'Huggingface' inference endpoints pass in `completion(model='huggingface/{model}',..)` Learn more: https://docs.litellm.ai/docs/providers")
return model, custom_llm_provider
except Exception as e: