Make completion work for Hugging Face Inference Endpoints (only strip the "details" parameter for the hosted api-inference.huggingface.co API)

This commit is contained in:
Krrish Dholakia 2023-09-11 18:37:55 -07:00
parent f946f61b4c
commit bab36c2c6f
4 changed files with 28 additions and 19 deletions

View file

@ -1,5 +1,5 @@
## Uses the huggingface text generation inference API
import os
import os, copy
import json
from enum import Enum
import requests
@ -61,11 +61,20 @@ def completion(
else:
prompt = prompt_factory(model=model, messages=messages)
### MAP INPUT PARAMS
data = {
"inputs": prompt,
"parameters": optional_params,
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
}
if "https://api-inference.huggingface.co/models" in completion_url:
inference_params = copy.deepcopy(optional_params)
inference_params.pop("details")
data = {
"inputs": prompt,
"parameters": inference_params,
"stream": True if "stream" in inference_params and inference_params["stream"] == True else False,
}
else:
data = {
"inputs": prompt,
"parameters": optional_params,
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
}
## LOGGING
logging_obj.pre_call(
input=prompt,