Litellm merge pr (#7161)

* build: merge branch

* test: fix openai naming

* fix(main.py): fix openai renaming

* style: ignore function length for config factory

* fix(sagemaker/): fix routing logic

* fix: fix imports

* fix: fix override
Krish Dholakia 2024-12-10 22:49:26 -08:00 committed by GitHub
parent d5aae81c6d
commit 350cfc36f7
88 changed files with 3617 additions and 4421 deletions


@@ -1606,30 +1606,33 @@ HF Tests we should pass
#####################################################
#####################################################
# Test util to sort models to TGI, conv, None
from litellm.llms.huggingface.chat.transformation import HuggingfaceChatConfig


def test_get_hf_task_for_model():
    model = "glaiveai/glaive-coder-7b"
    model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
    model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
    print(f"model:{model}, model type: {model_type}")
    assert model_type == "text-generation-inference"

    model = "meta-llama/Llama-2-7b-hf"
    model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
    model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
    print(f"model:{model}, model type: {model_type}")
    assert model_type == "text-generation-inference"

    model = "facebook/blenderbot-400M-distill"
    model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
    model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
    print(f"model:{model}, model type: {model_type}")
    assert model_type == "conversational"

    model = "facebook/blenderbot-3B"
    model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
    model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
    print(f"model:{model}, model type: {model_type}")
    assert model_type == "conversational"

    # neither Conv or None
    model = "roneneldan/TinyStories-3M"
    model_type, _ = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
    model_type, _ = HuggingfaceChatConfig().get_hf_task_for_model(model)
    print(f"model:{model}, model type: {model_type}")
    assert model_type == "text-generation"
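This hunk swaps the removed module-level helper litellm.llms.huggingface_restapi.get_hf_task_for_model for a method on HuggingfaceChatConfig, imported from litellm.llms.huggingface.chat.transformation. A minimal sketch of the new call pattern, assuming only what the diff shows (a no-arg config constructor, and a method returning a two-tuple whose first element is the task string the asserts check):

# Sketch of the new task-detection call used in the updated test.
# Assumption: get_hf_task_for_model returns (task, model), per the tuple unpacking above.
from litellm.llms.huggingface.chat.transformation import HuggingfaceChatConfig

task, _ = HuggingfaceChatConfig().get_hf_task_for_model("glaiveai/glaive-coder-7b")
print(task)  # the test expects "text-generation-inference" for this model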
@@ -1717,14 +1720,17 @@ def tgi_mock_post(url, **kwargs):
def test_hf_test_completion_tgi():
    litellm.set_verbose = True
    try:
        client = HTTPHandler()
        with patch("requests.post", side_effect=tgi_mock_post) as mock_client:
        with patch.object(client, "post", side_effect=tgi_mock_post) as mock_client:
            response = completion(
                model="huggingface/HuggingFaceH4/zephyr-7b-beta",
                messages=[{"content": "Hello, how are you?", "role": "user"}],
                max_tokens=10,
                wait_for_model=True,
                client=client,
            )
            mock_client.assert_called_once()
            # Add any assertions here to check the response
            print(response)
            assert "options" in mock_client.call_args.kwargs["data"]
@@ -1862,13 +1868,15 @@ def mock_post(url, **kwargs):
def test_hf_classifier_task():
    try:
        with patch("requests.post", side_effect=mock_post):
        client = HTTPHandler()
        with patch.object(client, "post", side_effect=mock_post):
            litellm.set_verbose = True
            user_message = "I like you. I love you"
            messages = [{"content": user_message, "role": "user"}]
            response = completion(
                model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
                messages=messages,
                client=client,
            )
            print(f"response: {response}")
            assert isinstance(response, litellm.ModelResponse)
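The classifier test applies the same pattern as the TGI test above: the requests.post patch becomes patch.object(client, "post") on a dedicated HTTPHandler, and the handler is threaded through completion() via client=client so the mock intercepts the request the library sends.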
@@ -3096,19 +3104,20 @@ async def test_completion_replicate_llama3(sync_mode):
            response = completion(
                model=model_name,
                messages=messages,
                max_tokens=10,
            )
        else:
            response = await litellm.acompletion(
                model=model_name,
                messages=messages,
                max_tokens=10,
            )
            print(f"ASYNC REPLICATE RESPONSE - {response}")
        print(response)
        print(f"REPLICATE RESPONSE - {response}")
        # Add any assertions here to check the response
        assert isinstance(response, litellm.ModelResponse)
        assert len(response.choices[0].message.content.strip()) > 0
        response_format_tests(response=response)
    except litellm.APIError as e:
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
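In the Replicate llama-3 hunk, both the sync completion() and async litellm.acompletion() calls carry max_tokens=10, the bare print(response) gives way to a labelled print, and validation runs through response_format_tests(response=response) alongside the isinstance check on litellm.ModelResponse; a provider-side litellm.APIError is tolerated rather than failing the test. A brief sketch of that error-tolerance pattern, built only from names visible in the hunk (the helper name is illustrative):

import pytest
import litellm

def assert_replicate_response(call):
    # Tolerate provider-side API errors; anything else is a real test failure.
    try:
        response = call()
        assert isinstance(response, litellm.ModelResponse)
    except litellm.APIError:
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")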
@@ -3745,22 +3754,6 @@ def test_mistral_anyscale_stream():
# pytest.fail(f"Error occurred: {e}")


#### Test A121 ###################
@pytest.mark.skip(reason="Local test")
def test_completion_ai21():
    print("running ai21 j2light test")
    litellm.set_verbose = True
    model_name = "j2-light"
    try:
        response = completion(
            model=model_name, messages=messages, max_tokens=100, temperature=0.8
        )
        # Add any assertions here to check the response
        print(response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_completion_ai21()
# test_completion_ai21()

## test deep infra
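Per its header, this final hunk shrinks the section from 22 lines to 6, dropping the long-skipped AI21 j2-light test body and leaving only the surrounding section markers in place.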