updated hf tests

This commit is contained in:
ishaan-jaff 2023-09-27 17:49:30 -07:00
parent 156d4f27de
commit 2c3da9acbb

View file

@ -139,6 +139,99 @@ def test_completion_with_litellm_call_id():
# pytest.fail(f"Error occurred: {e}")
# test_completion_nlp_cloud()
######### HUGGING FACE TESTS ########################
#####################################################
"""
HF Tests we should pass
- TGI:
- Pro Inference API
- Deployed Endpoint
- Coversational
- Free Inference API
- Deployed Endpoint
- Neither TGI or Coversational
- Free Inference API
- Deployed Endpoint
"""
#####################################################
#####################################################
# Test util to sort models to TGI, conv, None
def test_get_hf_task_for_model():
model = "glaiveai/glaive-coder-7b"
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert(model_type == "text-generation-inference")
model = "meta-llama/Llama-2-7b-hf"
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert(model_type == "text-generation-inference")
model = "facebook/blenderbot-400M-distill"
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert(model_type == "conversational")
model = "facebook/blenderbot-3B"
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert(model_type == "conversational")
# neither Conv or None
model = "roneneldan/TinyStories-3M"
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert(model_type == None)
# test_get_hf_task_for_model()
# litellm.set_verbose=False
# ################### Hugging Face TGI models ########################
# # TGI model
# # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b
# def hf_test_completion_tgi():
# try:
# response = litellm.completion(
# model="huggingface/glaiveai/glaive-coder-7b",
# messages=[{ "content": "Hello, how are you?","role": "user"}],
# api_base="https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud",
# )
# # Add any assertions here to check the response
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# hf_test_completion_tgi()
# ################### Hugging Face Conversational models ########################
# def hf_test_completion_conv():
# try:
# response = litellm.completion(
# model="huggingface/facebook/blenderbot-3B",
# messages=[{ "content": "Hello, how are you?","role": "user"}],
# )
# # Add any assertions here to check the response
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# hf_test_completion_conv()
# ################### Hugging Face Neither TGI or Conversational models ########################
# # Neither TGI or Conversational
# def hf_test_completion_none_task():
# try:
# user_message = "My name is Merve and my favorite"
# messages = [{ "content": user_message,"role": "user"}]
# response = completion(
# model="huggingface/roneneldan/TinyStories-3M",
# messages=messages,
# api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
# )
# # Add any assertions here to check the response
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# hf_test_completion_none_task()
########################### End of Hugging Face Tests ##############################################
# def test_completion_hf_api():
# # failing on circle ci commenting out
# try:
@ -181,24 +274,6 @@ def test_completion_with_litellm_call_id():
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# using Non TGI or conversational LLMs
def hf_test_completion():
try:
# litellm.set_verbose=True
user_message = "My name is Merve and my favorite"
messages = [{ "content": user_message,"role": "user"}]
response = completion(
model="huggingface/roneneldan/TinyStories-3M",
messages=messages,
api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
task=None,
)
# Add any assertions here to check the response
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# hf_test_completion()
# this should throw an exception, to trigger https://logs.litellm.ai/
# def hf_test_error_logs():