Merge branch 'meta-llama:main' into feat/litellm_sambanova_usage

Jorge Piedrahita Ortiz 2025-04-14 08:51:59 -05:00 committed by GitHub
commit dd808a8c1e
57 changed files with 1392 additions and 671 deletions


@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest

from ..test_cases.test_case import TestCase


def skip_if_provider_doesnt_support_batch_inference(client_with_models, model_id):
    models = {m.identifier: m for m in client_with_models.models.list()}
    models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
    provider_id = models[model_id].provider_id
    providers = {p.provider_id: p for p in client_with_models.providers.list()}
    provider = providers[provider_id]
    if provider.provider_type not in ("inline::meta-reference",):
        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support batch inference")


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:completion:batch_completion",
    ],
)
def test_batch_completion_non_streaming(client_with_models, text_model_id, test_case):
    skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id)
    tc = TestCase(test_case)

    content_batch = tc["contents"]
    response = client_with_models.inference.batch_completion(
        content_batch=content_batch,
        model_id=text_model_id,
        sampling_params={
            "max_tokens": 50,
        },
    )
    assert len(response.batch) == len(content_batch)
    for i, r in enumerate(response.batch):
        print(f"response {i}: {r.content}")
        assert len(r.content) > 10


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:chat_completion:batch_completion",
    ],
)
def test_batch_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
    skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id)
    tc = TestCase(test_case)
    qa_pairs = tc["qa_pairs"]

    message_batch = [
        [
            {
                "role": "user",
                "content": qa["question"],
            }
        ]
        for qa in qa_pairs
    ]

    response = client_with_models.inference.batch_chat_completion(
        messages_batch=message_batch,
        model_id=text_model_id,
    )
    assert len(response.batch) == len(qa_pairs)
    for i, r in enumerate(response.batch):
        print(f"response {i}: {r.completion_message.content}")
        assert len(r.completion_message.content) > 0
        assert qa_pairs[i]["answer"].lower() in r.completion_message.content.lower()
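
For context, the batch endpoints these tests exercise can also be called directly against a running stack. A minimal sketch follows, assuming a local Llama Stack server and that the installed llama-stack-client exposes the same inference.batch_completion / inference.batch_chat_completion methods used by the client_with_models fixture; the base URL, model id, and prompts below are placeholders.

# Minimal sketch -- assumes a reachable Llama Stack server at the placeholder URL
# and a model hosted by a provider that supports batch inference (e.g. inline::meta-reference).
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # placeholder URL
model_id = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder model id

# One completion per prompt in content_batch.
completions = client.inference.batch_completion(
    content_batch=["Write a haiku about llamas.", "Name three uses of alpaca wool."],
    model_id=model_id,
    sampling_params={"max_tokens": 50},
)
for i, r in enumerate(completions.batch):
    print(f"completion {i}: {r.content}")

# One chat completion per conversation in messages_batch.
chats = client.inference.batch_chat_completion(
    messages_batch=[[{"role": "user", "content": "What is the capital of France?"}]],
    model_id=model_id,
)
print(chats.batch[0].completion_message.content)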


@@ -5,7 +5,6 @@
# the root directory of this source tree.
import os
from time import sleep

import pytest
@@ -66,15 +65,6 @@ def get_llama_model(client_with_models, model_id):
    return model.metadata.get("llama_model", None)


def get_llama_tokenizer():
    from llama_models.llama3.api.chat_format import ChatFormat
    from llama_models.llama3.api.tokenizer import Tokenizer

    tokenizer = Tokenizer.get_instance()
    formatter = ChatFormat(tokenizer)
    return tokenizer, formatter


@pytest.mark.parametrize(
    "test_case",
    [
@@ -273,41 +263,6 @@ def test_text_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
    assert expected.lower() in message_content
@pytest.mark.parametrize(
    "test_case",
    [
        "inference:chat_completion:ttft",
    ],
)
def test_text_chat_completion_first_token_profiling(client_with_models, text_model_id, test_case):
    tc = TestCase(test_case)

    messages = tc["messages"]
    if os.environ.get("DEBUG_TTFT"):  # debugging print number of tokens in input, ideally around 800
        from pydantic import TypeAdapter

        from llama_stack.apis.inference import Message

        tokenizer, formatter = get_llama_tokenizer()
        typed_messages = [TypeAdapter(Message).validate_python(m) for m in messages]
        encoded = formatter.encode_dialog_prompt(typed_messages, None)
        raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0)

    response = client_with_models.inference.chat_completion(
        model_id=text_model_id,
        messages=messages,
        stream=False,
        timeout=120,  # Increase timeout to 2 minutes for large conversation history
    )
    message_content = response.completion_message.content.lower().strip()
    assert len(message_content) > 0

    if os.environ.get("DEBUG_TTFT"):  # debugging print number of tokens in response, ideally around 150
        tokenizer, formatter = get_llama_tokenizer()
        encoded = formatter.encode_content(message_content)
        raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0)
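
As a side note, the DEBUG_TTFT token counting above can be reproduced on its own. A minimal sketch, assuming llama_models is installed locally and mirroring the tokenizer calls from the test; the sample text is a placeholder and the count is printed rather than raised.

# Minimal sketch -- counts tokens the same way the DEBUG_TTFT branches do,
# but prints the number instead of raising ValueError.
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer

tokenizer = Tokenizer.get_instance()
formatter = ChatFormat(tokenizer)

encoded = formatter.encode_content("The llama grazed quietly on the hillside.")  # placeholder text
num_tokens = len(encoded.tokens) if encoded and encoded.tokens else 0
print(f"tokens in content: {num_tokens}")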
@pytest.mark.parametrize(
    "test_case",
    [