mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
work for hf inference endpoint
This commit is contained in:
parent
f946f61b4c
commit
bab36c2c6f
4 changed files with 28 additions and 19 deletions
|
@ -1,5 +1,5 @@
|
|||
## Uses the huggingface text generation inference API
|
||||
import os
|
||||
import os, copy
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests
|
||||
|
@ -61,11 +61,20 @@ def completion(
|
|||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages)
|
||||
### MAP INPUT PARAMS
|
||||
data = {
|
||||
"inputs": prompt,
|
||||
"parameters": optional_params,
|
||||
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
|
||||
}
|
||||
if "https://api-inference.huggingface.co/models" in completion_url:
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
inference_params.pop("details")
|
||||
data = {
|
||||
"inputs": prompt,
|
||||
"parameters": inference_params,
|
||||
"stream": True if "stream" in inference_params and inference_params["stream"] == True else False,
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
"inputs": prompt,
|
||||
"parameters": optional_params,
|
||||
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
|
||||
}
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue