forked from phoenix/litellm-mirror
Merge pull request #3571 from BerriAI/litellm_hf_classifier_support
Huggingface classifier support
This commit is contained in:
commit
1aa567f3b5
6 changed files with 415 additions and 64 deletions
|
@ -13,6 +13,7 @@ import litellm
|
|||
from litellm import embedding, completion, completion_cost, Timeout
|
||||
from litellm import RateLimitError
|
||||
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
# litellm.num_retries=3
|
||||
litellm.cache = None
|
||||
|
@ -1137,7 +1138,7 @@ def test_get_hf_task_for_model():
|
|||
model = "roneneldan/TinyStories-3M"
|
||||
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
|
||||
print(f"model:{model}, model type: {model_type}")
|
||||
assert model_type == None
|
||||
assert model_type == "text-generation"
|
||||
|
||||
|
||||
# test_get_hf_task_for_model()
|
||||
|
@ -1145,15 +1146,92 @@ def test_get_hf_task_for_model():
|
|||
# ################### Hugging Face TGI models ########################
|
||||
# # TGI model
|
||||
# # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b
|
||||
def hf_test_completion_tgi():
|
||||
# litellm.set_verbose=True
|
||||
def tgi_mock_post(url, data=None, json=None, headers=None):
    """Return a canned HTTP response shaped like a TGI text-generation reply.

    Meant to be used as the ``side_effect`` of a patched ``requests.post``.
    All arguments are accepted only so the signature is call-compatible;
    none of them influences the returned object.

    Returns:
        A ``MagicMock`` with ``status_code`` 200, a JSON ``Content-Type``
        header, and a ``json()`` method that yields a one-element list
        holding the generated text together with its per-token details.
    """
    # One (token id, token text, log-probability) row per generated token.
    token_rows = [
        (28789, "<", -0.025222778),
        (28766, "|", -0.000003695488),
        (489, "ass", -0.0000019073486),
        (11143, "istant", -0.000002026558),
        (28766, "|", -0.0000015497208),
        (28767, ">", -0.0000011920929),
        (13, "\n", -0.00009703636),
        (28737, "I", -0.1953125),
        (28742, "'", -0.88183594),
        (28719, "m", -0.00032639503),
    ]
    generation = {
        "generated_text": "<|assistant|>\nI'm",
        "details": {
            "finish_reason": "length",
            "generated_tokens": 10,
            "seed": None,
            "prefill": [],
            "tokens": [
                {
                    "id": token_id,
                    "text": token_text,
                    "logprob": logprob,
                    "special": False,
                }
                for token_id, token_text, logprob in token_rows
            ],
        },
    }

    fake_response = MagicMock(
        status_code=200,
        headers={"Content-Type": "application/json"},
    )
    fake_response.json.return_value = [generation]
    return fake_response
|
||||
|
||||
|
||||
def test_hf_test_completion_tgi():
|
||||
litellm.set_verbose = True
|
||||
try:
|
||||
response = completion(
|
||||
model="huggingface/HuggingFaceH4/zephyr-7b-beta",
|
||||
messages=[{"content": "Hello, how are you?", "role": "user"}],
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
with patch("requests.post", side_effect=tgi_mock_post):
|
||||
response = completion(
|
||||
model="huggingface/HuggingFaceH4/zephyr-7b-beta",
|
||||
messages=[{"content": "Hello, how are you?", "role": "user"}],
|
||||
max_tokens=10,
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
except litellm.ServiceUnavailableError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
|
@ -1191,6 +1269,40 @@ def hf_test_completion_tgi():
|
|||
# except Exception as e:
|
||||
# pytest.fail(f"Error occurred: {e}")
|
||||
# hf_test_completion_none_task()
|
||||
|
||||
|
||||
def mock_post(url, data=None, json=None, headers=None):
    """Return a canned HTTP response shaped like a text-classification reply.

    Meant to be used as the ``side_effect`` of a patched ``requests.post``.
    The arguments exist only for call compatibility and are ignored.

    Returns:
        A ``MagicMock`` with ``status_code`` 200, a JSON ``Content-Type``
        header, and a ``json()`` method yielding a nested list with one
        label/score dict per class.
    """
    label_scores = [
        {"label": "LABEL_0", "score": 0.9990691542625427},
        {"label": "LABEL_1", "score": 0.0009308889275416732},
    ]

    fake_response = MagicMock(
        status_code=200,
        headers={"Content-Type": "application/json"},
    )
    # The payload is a list of lists: the label scores sit inside an outer
    # single-element list.
    fake_response.json.return_value = [label_scores]
    return fake_response
|
||||
|
||||
|
||||
def test_hf_classifier_task():
    """Check that a Hugging Face text-classification call yields a chat-style response.

    ``requests.post`` is patched with ``mock_post`` for the duration of the
    call. The test then asserts that ``completion`` returns a
    ``litellm.ModelResponse`` whose first choice is a ``litellm.Choices``
    carrying non-empty string message content. Any exception, including a
    failed assertion, is reported through ``pytest.fail``.
    """
    try:
        with patch("requests.post", side_effect=mock_post):
            litellm.set_verbose = True
            prompt = "I like you. I love you"
            result = completion(
                model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
                messages=[{"content": prompt, "role": "user"}],
            )
            print(f"response: {result}")

            assert isinstance(result, litellm.ModelResponse)
            first_choice = result.choices[0]
            assert isinstance(first_choice, litellm.Choices)
            content = first_choice.message.content
            assert content is not None
            assert isinstance(content, str)
    except Exception as e:
        pytest.fail(f"Error occurred: {e!s}")
|
||||
|
||||
|
||||
########################### End of Hugging Face Tests ##############################################
|
||||
# def test_completion_hf_api():
|
||||
# # failing on circle ci commenting out
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue