Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
commit 5c17f90173: add hf support
parent: d72dc244b1

3 changed files with 40 additions and 1 deletion
@@ -10,6 +10,8 @@ azure_key = None
 anthropic_key = None
 replicate_key = None
 cohere_key = None
+
+hugging_api_token = None
 ####### THREAD-SPECIFIC DATA ###################
 class MyLocal(threading.local):
     def __init__(self):
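Note that this hunk only declares a module-level `hugging_api_token` placeholder; the completion path added below reads the token through `get_secret("HF_TOKEN")` instead. A minimal usage sketch, assuming `get_secret` falls back to environment variables (an assumption, not shown in this diff):

import os
import litellm

# Assumption: get_secret("HF_TOKEN") falls back to os.environ, so exporting
# the variable before calling completion() is enough.
os.environ["HF_TOKEN"] = "hf_xxx"  # hypothetical placeholder token

# The module-level placeholder added in this hunk can also be set directly:
litellm.hugging_api_token = os.environ["HF_TOKEN"]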
@@ -44,7 +44,8 @@ def completion(
     temperature=1, top_p=1, n=1, stream=False, stop=None, max_tokens=float('inf'),
     presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None,
     # Optional liteLLM function params
-    *, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False
+    *, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False,
+    hugging_face = False
 ):
     try:
         global new_response
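For illustration, a call exercising the new flag might look like the sketch below (the model id is the one used by the test added later in this commit). Since `hugging_face` sits after the bare `*` in the signature, it must be passed by name:

from litellm import completion

messages = [{"role": "user", "content": "Hello, how are you?"}]

# hugging_face is keyword-only, so it has to be passed by name.
response = completion(
    model="stabilityai/stablecode-completion-alpha-3b-4k",
    messages=messages,
    hugging_face=True,
)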
@@ -273,6 +274,32 @@ def completion(
                "total_tokens": prompt_tokens + completion_tokens
            }
            response = model_response
+       elif hugging_face == True:
+           import requests
+           API_URL = f"https://api-inference.huggingface.co/models/{model}"
+           HF_TOKEN = get_secret("HF_TOKEN")
+           headers = {"Authorization": f"Bearer {HF_TOKEN}"}
+
+           prompt = " ".join([message["content"] for message in messages])
+           ## LOGGING
+           logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
+           input_payload = {"inputs": prompt}
+           response = requests.post(API_URL, headers=headers, json=input_payload)
+
+           completion_response = response.json()[0]['generated_text']
+           ## LOGGING
+           logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+           prompt_tokens = len(encoding.encode(prompt))
+           completion_tokens = len(encoding.encode(completion_response))
+           ## RESPONSE OBJECT
+           model_response["choices"][0]["message"]["content"] = completion_response
+           model_response["created"] = time.time()
+           model_response["model"] = model
+           model_response["usage"] = {
+               "prompt_tokens": prompt_tokens,
+               "completion_tokens": completion_tokens,
+               "total_tokens": prompt_tokens + completion_tokens
+           }
        else:
            ## LOGGING
            logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
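The new branch flattens the chat messages into a single prompt, POSTs it to the Hugging Face Inference API, and rebuilds an OpenAI-style response. Two things worth noting: unlike the branch above it, it never assigns `model_response` back to `response`, and it indexes into the JSON body without checking the HTTP status. A standalone sketch of the same round trip with that check added (reading `HF_TOKEN` from the environment is an assumption):

import os
import requests

def hf_generate(model: str, prompt: str) -> str:
    # Same Inference API round trip as the elif branch above.
    url = f"https://api-inference.huggingface.co/models/{model}"
    headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
    resp = requests.post(url, headers=headers, json={"inputs": prompt})
    resp.raise_for_status()  # the diff skips this; a non-200 body would raise a KeyError below
    return resp.json()[0]["generated_text"]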
@@ -26,6 +26,16 @@ def test_completion_claude():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+def test_completion_hf_api():
+    try:
+        user_message = "write some code to find the sum of two numbers"
+        messages = [{ "content": user_message,"role": "user"}]
+        response = completion(model="stabilityai/stablecode-completion-alpha-3b-4k", messages=messages, hugging_face=True)
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
 def test_completion_cohere():
     try:
         response = completion(model="command-nightly", messages=messages, max_tokens=500)
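The new test only prints the result. A sketch of assertions that could replace the placeholder comment, based on the response shape assembled in the hugging_face branch above:

# Possible assertions for the "Add any assertions here" placeholder,
# matching the usage dict built in the hugging_face branch:
assert response["choices"][0]["message"]["content"]
assert response["usage"]["total_tokens"] == (
    response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
)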