add hf support

ishaan-jaff 2023-08-08 13:35:37 -07:00
parent 1223fed6ba
commit 4c76359fa1
3 changed files with 40 additions and 1 deletion


@@ -44,7 +44,8 @@ def completion(
     temperature=1, top_p=1, n=1, stream=False, stop=None, max_tokens=float('inf'),
     presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None,
     # Optional liteLLM function params
-    *, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False
+    *, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False,
+    hugging_face=False
 ):
     try:
         global new_response
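
For context, a minimal usage sketch of the new flag (assuming litellm is installed, that get_secret("HF_TOKEN") resolves the token from the environment, and that the model id below is purely illustrative):

import os
from litellm import completion

os.environ["HF_TOKEN"] = "hf_..."  # placeholder Hugging Face API token

messages = [{"role": "user", "content": "Hello, how are you?"}]

# hugging_face=True routes the call to the hosted Inference API at
# https://api-inference.huggingface.co/models/<model> instead of OpenAI/Azure.
response = completion(model="bigscience/bloom", messages=messages, hugging_face=True)
print(response["choices"][0]["message"]["content"])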
@@ -273,6 +274,32 @@ def completion(
                 "total_tokens": prompt_tokens + completion_tokens
             }
             response = model_response
+        elif hugging_face == True:
+            import requests
+            API_URL = f"https://api-inference.huggingface.co/models/{model}"
+            HF_TOKEN = get_secret("HF_TOKEN")
+            headers = {"Authorization": f"Bearer {HF_TOKEN}"}
+            prompt = " ".join([message["content"] for message in messages])
+            ## LOGGING
+            logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
+            input_payload = {"inputs": prompt}
+            response = requests.post(API_URL, headers=headers, json=input_payload)
+            completion_response = response.json()[0]['generated_text']
+            ## LOGGING
+            logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+            prompt_tokens = len(encoding.encode(prompt))
+            completion_tokens = len(encoding.encode(completion_response))
+            ## RESPONSE OBJECT
+            model_response["choices"][0]["message"]["content"] = completion_response
+            model_response["created"] = time.time()
+            model_response["model"] = model
+            model_response["usage"] = {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": prompt_tokens + completion_tokens
+            }
         else:
             ## LOGGING
             logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
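
Stripped of logging and token accounting, the new branch amounts to a single POST against the hosted Inference API. A standalone sketch of that request (model id and token are placeholders):

import requests

model = "bigscience/bloom"  # illustrative model id
API_URL = f"https://api-inference.huggingface.co/models/{model}"
headers = {"Authorization": "Bearer hf_..."}  # placeholder API token

# Text-generation models return a list of generations; the branch above
# reads the first item's generated_text field.
resp = requests.post(API_URL, headers=headers, json={"inputs": "Hello, how are you?"})
resp.raise_for_status()
print(resp.json()[0]["generated_text"])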