# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import pytest

from llama_stack.models.llama.sku_list import resolve_model

from ..test_cases.test_case import TestCase

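# Provider types expected to support top-k logprobs (inferred from the name; not
# exercised by the tests in this file).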
PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}


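# Skip when the provider hosting the model does not support the (non-chat)
# completion API.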
def skip_if_model_doesnt_support_completion(client_with_models, model_id):
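    # Index models by both identifier and provider_resource_id so that either
    # form of model_id resolves.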
    models = {m.identifier: m for m in client_with_models.models.list()}
    models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
    provider_id = models[model_id].provider_id
    providers = {p.provider_id: p for p in client_with_models.providers.list()}
    provider = providers[provider_id]
    if provider.provider_type in (
        "remote::openai",
        "remote::anthropic",
        "remote::gemini",
        "remote::groq",
        "remote::llama-openai-compat",
    ):
        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")


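# Return an ID for this model that resolve_model() recognizes as a known Llama
# SKU, falling back to the model's "llama_model" metadata entry.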
def get_llama_model(client_with_models, model_id):
    models = {}
    for m in client_with_models.models.list():
        models[m.identifier] = m
        models[m.provider_resource_id] = m

    assert model_id in models, f"Model {model_id} not found"

    model = models[model_id]
    ids = (model.identifier, model.provider_resource_id)
    for mid in ids:
        if resolve_model(mid):
            return mid
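
    # Neither ID resolved to a known Llama SKU; fall back to explicit metadata.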
    return model.metadata.get("llama_model", None)


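# Imports are kept local so the test module loads even when the optional
# llama_models package is not installed.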
def get_llama_tokenizer():
    from llama_models.llama3.api.chat_format import ChatFormat
    from llama_models.llama3.api.tokenizer import Tokenizer

    tokenizer = Tokenizer.get_instance()
    formatter = ChatFormat(tokenizer)
    return tokenizer, formatter


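# Exercise the batch completion API with the plain-text prompts defined in the
# test-case fixture.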
@pytest.mark.parametrize(
    "test_case",
    [
        "inference:completion:batch_completion",
    ],
)
def test_batch_completion_non_streaming(client_with_models, text_model_id, test_case):
    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
    tc = TestCase(test_case)

    content_batch = tc["contents"]
    response = client_with_models.inference.batch_completion(
        content_batch=content_batch,
        model_id=text_model_id,
        sampling_params={
            "max_tokens": 50,
        },
    )
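    # Expect one result per input prompt, each with a non-trivial amount of text.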
    assert len(response.batch) == len(content_batch)
    for i, r in enumerate(response.batch):
        print(f"response {i}: {r.content}")
        assert len(r.content) > 10


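# Exercise the batch chat completion API with single-turn user questions.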
@pytest.mark.parametrize(
    "test_case",
    [
        "inference:chat_completion:batch_completion",
    ],
)
def test_batch_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
    tc = TestCase(test_case)
    qa_pairs = tc["qa_pairs"]

    message_batch = [
        [
            {
                "role": "user",
                "content": qa["question"],
            }
        ]
        for qa in qa_pairs
    ]

    response = client_with_models.inference.batch_chat_completion(
        messages_batch=message_batch,
        model_id=text_model_id,
    )
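    # Each response should be non-empty and contain the expected answer substring.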
    assert len(response.batch) == len(qa_pairs)
    for i, r in enumerate(response.batch):
        print(f"response {i}: {r.completion_message.content}")
        assert len(r.completion_message.content) > 0
        assert qa_pairs[i]["answer"].lower() in r.completion_message.content.lower()