Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
Litellm merge pr (#7161)

* build: merge branch
* test: fix openai naming
* fix(main.py): fix openai renaming
* style: ignore function length for config factory
* fix(sagemaker/): fix routing logic
* fix: fix imports
* fix: fix override

parent d5aae81c6d
commit 350cfc36f7

88 changed files with 3617 additions and 4421 deletions
@@ -1606,30 +1606,33 @@ HF Tests we should pass
 #####################################################
 #####################################################
 # Test util to sort models to TGI, conv, None
+from litellm.llms.huggingface.chat.transformation import HuggingfaceChatConfig
+
+
 def test_get_hf_task_for_model():
     model = "glaiveai/glaive-coder-7b"
-    model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
+    model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
     print(f"model:{model}, model type: {model_type}")
     assert model_type == "text-generation-inference"
 
     model = "meta-llama/Llama-2-7b-hf"
-    model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
+    model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
     print(f"model:{model}, model type: {model_type}")
     assert model_type == "text-generation-inference"
 
     model = "facebook/blenderbot-400M-distill"
-    model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
+    model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
     print(f"model:{model}, model type: {model_type}")
     assert model_type == "conversational"
 
     model = "facebook/blenderbot-3B"
-    model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
+    model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
     print(f"model:{model}, model type: {model_type}")
     assert model_type == "conversational"
 
     # neither Conv or None
     model = "roneneldan/TinyStories-3M"
-    model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
+    model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
     print(f"model:{model}, model type: {model_type}")
     assert model_type == "text-generation"
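
The hunk above swaps the module-level litellm.llms.huggingface_restapi.get_hf_task_for_model helper for a method on the new HuggingfaceChatConfig transformation class. A minimal sketch of calling it directly, assuming only the import path and the tuple return shape visible in the diff:

    from litellm.llms.huggingface.chat.transformation import HuggingfaceChatConfig

    config = HuggingfaceChatConfig()

    # The tests above use only the first element of the returned tuple:
    # the task string, e.g. "text-generation-inference", "conversational",
    # or "text-generation".
    task, _ = config.get_hf_task_for_model("glaiveai/glaive-coder-7b")
    print(task)  # the test asserts "text-generation-inference" for this model
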
@@ -1717,14 +1720,17 @@ def tgi_mock_post(url, **kwargs):
 def test_hf_test_completion_tgi():
     litellm.set_verbose = True
     try:
+        client = HTTPHandler()
+
-        with patch("requests.post", side_effect=tgi_mock_post) as mock_client:
+        with patch.object(client, "post", side_effect=tgi_mock_post) as mock_client:
             response = completion(
                 model="huggingface/HuggingFaceH4/zephyr-7b-beta",
                 messages=[{"content": "Hello, how are you?", "role": "user"}],
                 max_tokens=10,
                 wait_for_model=True,
+                client=client,
             )
             mock_client.assert_called_once()
             # Add any assertions here to check the response
             print(response)
             assert "options" in mock_client.call_args.kwargs["data"]
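
This hunk moves the test off patching the global requests.post and onto patching an injected HTTPHandler, which completion() then uses via its client parameter. A minimal sketch of the pattern, assuming HTTPHandler is importable from litellm.llms.custom_httpx.http_handler (its usual home in litellm tests, not shown in this diff) and using a hypothetical fake_post stub in place of the file's tgi_mock_post:

    from unittest.mock import MagicMock, patch

    from litellm import completion
    from litellm.llms.custom_httpx.http_handler import HTTPHandler


    def fake_post(url, **kwargs):
        # Hypothetical stand-in for tgi_mock_post: return a canned
        # TGI-style payload so no network call is made.
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.headers = {"Content-Type": "application/json"}
        mock_response.json.return_value = [{"generated_text": "I'm doing well!"}]
        return mock_response


    client = HTTPHandler()
    # Patching the injected client's post method (rather than requests.post)
    # keeps the test pinned to litellm's own HTTP code path.
    with patch.object(client, "post", side_effect=fake_post) as mock_post:
        response = completion(
            model="huggingface/HuggingFaceH4/zephyr-7b-beta",
            messages=[{"content": "Hello, how are you?", "role": "user"}],
            max_tokens=10,
            client=client,  # route the request through the patched handler
        )
        mock_post.assert_called_once()

The next hunk applies the same client-injection pattern to test_hf_classifier_task.
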
@@ -1862,13 +1868,15 @@ def mock_post(url, **kwargs):
 
 def test_hf_classifier_task():
     try:
-        with patch("requests.post", side_effect=mock_post):
+        client = HTTPHandler()
+        with patch.object(client, "post", side_effect=mock_post):
             litellm.set_verbose = True
             user_message = "I like you. I love you"
             messages = [{"content": user_message, "role": "user"}]
             response = completion(
                 model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
                 messages=messages,
+                client=client,
             )
             print(f"response: {response}")
             assert isinstance(response, litellm.ModelResponse)
@@ -3096,19 +3104,20 @@ async def test_completion_replicate_llama3(sync_mode):
             response = completion(
                 model=model_name,
                 messages=messages,
                 max_tokens=10,
             )
         else:
             response = await litellm.acompletion(
                 model=model_name,
                 messages=messages,
                 max_tokens=10,
             )
-            print(f"ASYNC REPLICATE RESPONSE - {response}")
-        print(response)
+        print(f"REPLICATE RESPONSE - {response}")
         # Add any assertions here to check the response
         assert isinstance(response, litellm.ModelResponse)
         assert len(response.choices[0].message.content.strip()) > 0
         response_format_tests(response=response)
+    except litellm.APIError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
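
Putting the fragments together: the replicate test runs either the sync or the async client depending on a sync_mode parameter, and it now swallows litellm.APIError so upstream availability issues do not fail the suite. A reconstruction sketch under stated assumptions — the model name and messages are illustrative, and the file's response_format_tests helper is omitted:

    import pytest

    import litellm
    from litellm import completion


    @pytest.mark.parametrize("sync_mode", [True, False])
    @pytest.mark.asyncio
    async def test_completion_replicate_llama3(sync_mode):
        model_name = "replicate/meta/meta-llama-3-8b-instruct"  # illustrative
        messages = [{"content": "Hello, how are you?", "role": "user"}]
        try:
            if sync_mode:
                response = completion(model=model_name, messages=messages, max_tokens=10)
            else:
                response = await litellm.acompletion(
                    model=model_name, messages=messages, max_tokens=10
                )
            print(f"REPLICATE RESPONSE - {response}")
            assert isinstance(response, litellm.ModelResponse)
            assert len(response.choices[0].message.content.strip()) > 0
        except litellm.APIError:
            # Provider-side API errors are tolerated; anything else fails.
            pass
        except Exception as e:
            pytest.fail(f"Error occurred: {e}")
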
@@ -3745,22 +3754,6 @@ def test_mistral_anyscale_stream():
 #     pytest.fail(f"Error occurred: {e}")
 
 
-#### Test A121 ###################
-@pytest.mark.skip(reason="Local test")
-def test_completion_ai21():
-    print("running ai21 j2light test")
-    litellm.set_verbose = True
-    model_name = "j2-light"
-    try:
-        response = completion(
-            model=model_name, messages=messages, max_tokens=100, temperature=0.8
-        )
-        # Add any assertions here to check the response
-        print(response)
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
 # test_completion_ai21()
 # test_completion_ai21()
 ## test deep infra