docs(huggingface.md): add text-classification to huggingface docs

This commit is contained in:
Krrish Dholakia 2024-05-10 14:39:14 -07:00
parent 50be25d11a
commit d4d175030f
3 changed files with 177 additions and 9 deletions

View file

@ -230,6 +230,8 @@ def read_tgi_conv_models():
def get_hf_task_for_model(model: str) -> hf_tasks:
# read text file, cast it to set
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
if model.split("/")[0] in hf_task_list:
return model.split("/")[0] # type: ignore
tgi_models, conversational_models = read_tgi_conv_models()
if model in tgi_models:
return "text-generation-inference"
@ -401,10 +403,7 @@ class Huggingface(BaseLLM):
exception_mapping_worked = False
try:
headers = self.validate_environment(api_key, headers)
if optional_params.get("hf_task") is None:
task = get_hf_task_for_model(model)
else:
task = optional_params.get("hf_task") # type: ignore
task = get_hf_task_for_model(model)
## VALIDATE API FORMAT
if task is None or not isinstance(task, str) or task not in hf_task_list:
raise Exception(