mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
refactor(test): move tools, evals, datasetio, scoring and post training tests (#1401)
All of the tests from `llama_stack/providers/tests/` are now moved to `tests/integration`. I converted the `tools`, `scoring` and `datasetio` tests to use API. However, `eval` and `post_training` proved to be a bit challenging to leaving those. I think `post_training` should be relatively straightforward also. As part of this, I noticed that `wolfram_alpha` tool wasn't added to some of our commonly used distros so I added it. I am going to remove a lot of code duplication from distros next so while this looks like a one-off right now, it will go away and be there uniformly for all distros.
This commit is contained in:
parent
dd0db8038b
commit
abfbaf3c1b
51 changed files with 471 additions and 1245 deletions
|
@ -20,7 +20,7 @@ from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
|||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.stack import replace_env_vars
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.providers.tests.env import get_env_or_fail
|
||||
from llama_stack.env import get_env_or_fail
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
from .fixtures.recordable_mock import RecordableMock
|
||||
|
@ -84,6 +84,11 @@ def pytest_addoption(parser):
|
|||
default=None,
|
||||
help="Specify the embedding model to use for testing",
|
||||
)
|
||||
parser.addoption(
|
||||
"--judge-model",
|
||||
default=None,
|
||||
help="Specify the judge model to use for testing",
|
||||
)
|
||||
parser.addoption(
|
||||
"--embedding-dimension",
|
||||
type=int,
|
||||
|
@ -109,6 +114,7 @@ def provider_data():
|
|||
"TOGETHER_API_KEY": "together_api_key",
|
||||
"ANTHROPIC_API_KEY": "anthropic_api_key",
|
||||
"GROQ_API_KEY": "groq_api_key",
|
||||
"WOLFRAM_ALPHA_API_KEY": "wolfram_alpha_api_key",
|
||||
}
|
||||
provider_data = {}
|
||||
for key, value in keymap.items():
|
||||
|
@ -260,7 +266,9 @@ def inference_provider_type(llama_stack_client):
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def client_with_models(llama_stack_client, text_model_id, vision_model_id, embedding_model_id, embedding_dimension):
|
||||
def client_with_models(
|
||||
llama_stack_client, text_model_id, vision_model_id, embedding_model_id, embedding_dimension, judge_model_id
|
||||
):
|
||||
client = llama_stack_client
|
||||
|
||||
providers = [p for p in client.providers.list() if p.api == "inference"]
|
||||
|
@ -274,6 +282,8 @@ def client_with_models(llama_stack_client, text_model_id, vision_model_id, embed
|
|||
client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
|
||||
if vision_model_id and vision_model_id not in model_ids:
|
||||
client.models.register(model_id=vision_model_id, provider_id=inference_providers[0])
|
||||
if judge_model_id and judge_model_id not in model_ids:
|
||||
client.models.register(model_id=judge_model_id, provider_id=inference_providers[0])
|
||||
|
||||
if embedding_model_id and embedding_dimension and embedding_model_id not in model_ids:
|
||||
# try to find a provider that supports embeddings, if sentence-transformers is not available
|
||||
|
@ -328,6 +338,14 @@ def pytest_generate_tests(metafunc):
|
|||
if val is not None:
|
||||
id_parts.append(f"emb={get_short_id(val)}")
|
||||
|
||||
if "judge_model_id" in metafunc.fixturenames:
|
||||
params.append("judge_model_id")
|
||||
val = metafunc.config.getoption("--judge-model")
|
||||
print(f"judge_model_id: {val}")
|
||||
values.append(val)
|
||||
if val is not None:
|
||||
id_parts.append(f"judge={get_short_id(val)}")
|
||||
|
||||
if "embedding_dimension" in metafunc.fixturenames:
|
||||
params.append("embedding_dimension")
|
||||
val = metafunc.config.getoption("--embedding-dimension")
|
||||
|
|
5
tests/integration/datasetio/__init__.py
Normal file
5
tests/integration/datasetio/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
6
tests/integration/datasetio/test_dataset.csv
Normal file
6
tests/integration/datasetio/test_dataset.csv
Normal file
|
@ -0,0 +1,6 @@
|
|||
input_query,generated_answer,expected_answer,chat_completion_input
|
||||
What is the capital of France?,London,Paris,"[{'role': 'user', 'content': 'What is the capital of France?'}]"
|
||||
Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg,"[{'role': 'user', 'content': 'Who is the CEO of Meta?'}]"
|
||||
What is the largest planet in our solar system?,Jupiter,Jupiter,"[{'role': 'user', 'content': 'What is the largest planet in our solar system?'}]"
|
||||
What is the smallest country in the world?,China,Vatican City,"[{'role': 'user', 'content': 'What is the smallest country in the world?'}]"
|
||||
What is the currency of Japan?,Yen,Yen,"[{'role': 'user', 'content': 'What is the currency of Japan?'}]"
|
|
118
tests/integration/datasetio/test_datasetio.py
Normal file
118
tests/integration/datasetio/test_datasetio.py
Normal file
|
@ -0,0 +1,118 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest llama_stack/providers/tests/datasetio/test_datasetio.py
|
||||
# -m "meta_reference"
|
||||
# -v -s --tb=short --disable-warnings
|
||||
|
||||
|
||||
def data_url_from_file(file_path: str) -> str:
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, "rb") as file:
|
||||
file_content = file.read()
|
||||
|
||||
base64_content = base64.b64encode(file_content).decode("utf-8")
|
||||
mime_type, _ = mimetypes.guess_type(file_path)
|
||||
|
||||
data_url = f"data:{mime_type};base64,{base64_content}"
|
||||
|
||||
return data_url
|
||||
|
||||
|
||||
def register_dataset(llama_stack_client, for_generation=False, for_rag=False, dataset_id="test_dataset"):
|
||||
if for_rag:
|
||||
test_file = Path(os.path.abspath(__file__)).parent / "test_rag_dataset.csv"
|
||||
else:
|
||||
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
|
||||
test_url = data_url_from_file(str(test_file))
|
||||
|
||||
if for_generation:
|
||||
dataset_schema = {
|
||||
"expected_answer": {"type": "string"},
|
||||
"input_query": {"type": "string"},
|
||||
"chat_completion_input": {"type": "chat_completion_input"},
|
||||
}
|
||||
elif for_rag:
|
||||
dataset_schema = {
|
||||
"expected_answer": {"type": "string"},
|
||||
"input_query": {"type": "string"},
|
||||
"generated_answer": {"type": "string"},
|
||||
"context": {"type": "string"},
|
||||
}
|
||||
else:
|
||||
dataset_schema = {
|
||||
"expected_answer": {"type": "string"},
|
||||
"input_query": {"type": "string"},
|
||||
"generated_answer": {"type": "string"},
|
||||
}
|
||||
|
||||
llama_stack_client.datasets.register(
|
||||
dataset_id=dataset_id,
|
||||
dataset_schema=dataset_schema,
|
||||
url=dict(uri=test_url),
|
||||
provider_id="localfs",
|
||||
)
|
||||
|
||||
|
||||
def test_datasets_list(llama_stack_client):
|
||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||
# but so far we don't have an unregister API unfortunately, so be careful
|
||||
|
||||
response = llama_stack_client.datasets.list()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 0
|
||||
|
||||
|
||||
def test_register_dataset(llama_stack_client):
|
||||
register_dataset(llama_stack_client)
|
||||
response = llama_stack_client.datasets.list()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 1
|
||||
assert response[0].identifier == "test_dataset"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# unregister a dataset that does not exist
|
||||
llama_stack_client.datasets.unregister("test_dataset2")
|
||||
|
||||
llama_stack_client.datasets.unregister("test_dataset")
|
||||
response = llama_stack_client.datasets.list()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 0
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
llama_stack_client.datasets.unregister("test_dataset")
|
||||
|
||||
|
||||
def test_get_rows_paginated(llama_stack_client):
|
||||
register_dataset(llama_stack_client)
|
||||
response = llama_stack_client.datasetio.get_rows_paginated(
|
||||
dataset_id="test_dataset",
|
||||
rows_in_page=3,
|
||||
)
|
||||
assert isinstance(response.rows, list)
|
||||
assert len(response.rows) == 3
|
||||
assert response.next_page_token == "3"
|
||||
|
||||
# iterate over all rows
|
||||
response = llama_stack_client.datasetio.get_rows_paginated(
|
||||
dataset_id="test_dataset",
|
||||
rows_in_page=2,
|
||||
page_token=response.next_page_token,
|
||||
)
|
||||
assert isinstance(response.rows, list)
|
||||
assert len(response.rows) == 2
|
||||
assert response.next_page_token == "5"
|
6
tests/integration/datasetio/test_rag_dataset.csv
Normal file
6
tests/integration/datasetio/test_rag_dataset.csv
Normal file
|
@ -0,0 +1,6 @@
|
|||
input_query,context,generated_answer,expected_answer
|
||||
What is the capital of France?,"France is a country in Western Europe with a population of about 67 million people. Its capital city has been a major European cultural center since the 17th century and is known for landmarks like the Eiffel Tower and the Louvre Museum.",London,Paris
|
||||
Who is the CEO of Meta?,"Meta Platforms, formerly known as Facebook, is one of the world's largest technology companies. Founded by Mark Zuckerberg in 2004, the company has expanded to include platforms like Instagram, WhatsApp, and virtual reality technologies.",Mark Zuckerberg,Mark Zuckerberg
|
||||
What is the largest planet in our solar system?,"The solar system consists of eight planets orbiting around the Sun. These planets, in order from the Sun, are Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, and Neptune. Gas giants are significantly larger than terrestrial planets.",Jupiter,Jupiter
|
||||
What is the smallest country in the world?,"Independent city-states and micronations are among the world's smallest sovereign territories. Some notable examples include Monaco, San Marino, and Vatican City, which is an enclave within Rome, Italy.",China,Vatican City
|
||||
What is the currency of Japan?,"Japan is an island country in East Asia with a rich cultural heritage and one of the world's largest economies. Its financial system has been established since the Meiji period, with its modern currency being introduced in 1871.",Yen,Yen
|
|
5
tests/integration/eval/__init__.py
Normal file
5
tests/integration/eval/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
20
tests/integration/eval/constants.py
Normal file
20
tests/integration/eval/constants.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
JUDGE_PROMPT = """
|
||||
You will be given a question, a expected_answer, and a system_answer.
|
||||
Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.
|
||||
Give your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
|
||||
Provide your feedback as follows:
|
||||
Feedback:::
|
||||
Total rating: (your rating, as a int between 0 and 5)
|
||||
Now here are the question, expected_answer, system_answer.
|
||||
Question: {input_query}
|
||||
Expected Answer: {expected_answer}
|
||||
System Answer: {generated_answer}
|
||||
Feedback:::
|
||||
Total rating:
|
||||
"""
|
183
tests/integration/eval/test_eval.py
Normal file
183
tests/integration/eval/test_eval.py
Normal file
|
@ -0,0 +1,183 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
|
||||
from llama_stack.apis.eval.eval import (
|
||||
ModelCandidate,
|
||||
)
|
||||
from llama_stack.apis.inference import SamplingParams
|
||||
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from ..datasetio.test_datasetio import register_dataset
|
||||
from .constants import JUDGE_PROMPT
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest llama_stack/providers/tests/eval/test_eval.py
|
||||
# -m "meta_reference_eval_together_inference_huggingface_datasetio"
|
||||
# -v -s --tb=short --disable-warnings
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="FIXME FIXME @yanxi0830 this needs to be migrated to use the API")
|
||||
class Testeval:
|
||||
@pytest.mark.asyncio
|
||||
async def test_benchmarks_list(self, eval_stack):
|
||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||
# but so far we don't have an unregister API unfortunately, so be careful
|
||||
benchmarks_impl = eval_stack[Api.benchmarks]
|
||||
response = await benchmarks_impl.list_benchmarks()
|
||||
assert isinstance(response, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_eval_evaluate_rows(self, eval_stack, inference_model, judge_model):
|
||||
eval_impl, benchmarks_impl, datasetio_impl, datasets_impl = (
|
||||
eval_stack[Api.eval],
|
||||
eval_stack[Api.benchmarks],
|
||||
eval_stack[Api.datasetio],
|
||||
eval_stack[Api.datasets],
|
||||
)
|
||||
|
||||
await register_dataset(datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval")
|
||||
response = await datasets_impl.list_datasets()
|
||||
|
||||
rows = await datasetio_impl.get_rows_paginated(
|
||||
dataset_id="test_dataset_for_eval",
|
||||
rows_in_page=3,
|
||||
)
|
||||
assert len(rows.rows) == 3
|
||||
|
||||
scoring_functions = [
|
||||
"basic::equality",
|
||||
]
|
||||
benchmark_id = "meta-reference::app_eval"
|
||||
await benchmarks_impl.register_benchmark(
|
||||
benchmark_id=benchmark_id,
|
||||
dataset_id="test_dataset_for_eval",
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
response = await eval_impl.evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
benchmark_config=dict(
|
||||
eval_candidate=ModelCandidate(
|
||||
model=inference_model,
|
||||
sampling_params=SamplingParams(),
|
||||
),
|
||||
scoring_params={
|
||||
"meta-reference::llm_as_judge_base": LLMAsJudgeScoringFnParams(
|
||||
judge_model=judge_model,
|
||||
prompt_template=JUDGE_PROMPT,
|
||||
judge_score_regexes=[
|
||||
r"Total rating: (\d+)",
|
||||
r"rating: (\d+)",
|
||||
r"Rating: (\d+)",
|
||||
],
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
assert len(response.generations) == 3
|
||||
assert "basic::equality" in response.scores
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_eval_run_eval(self, eval_stack, inference_model, judge_model):
|
||||
eval_impl, benchmarks_impl, datasets_impl = (
|
||||
eval_stack[Api.eval],
|
||||
eval_stack[Api.benchmarks],
|
||||
eval_stack[Api.datasets],
|
||||
)
|
||||
|
||||
await register_dataset(datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval")
|
||||
|
||||
scoring_functions = [
|
||||
"basic::subset_of",
|
||||
]
|
||||
|
||||
benchmark_id = "meta-reference::app_eval-2"
|
||||
await benchmarks_impl.register_benchmark(
|
||||
benchmark_id=benchmark_id,
|
||||
dataset_id="test_dataset_for_eval",
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
response = await eval_impl.run_eval(
|
||||
benchmark_id=benchmark_id,
|
||||
benchmark_config=dict(
|
||||
eval_candidate=ModelCandidate(
|
||||
model=inference_model,
|
||||
sampling_params=SamplingParams(),
|
||||
),
|
||||
),
|
||||
)
|
||||
assert response.job_id == "0"
|
||||
job_status = await eval_impl.job_status(benchmark_id, response.job_id)
|
||||
assert job_status and job_status.value == "completed"
|
||||
eval_response = await eval_impl.job_result(benchmark_id, response.job_id)
|
||||
|
||||
assert eval_response is not None
|
||||
assert len(eval_response.generations) == 5
|
||||
assert "basic::subset_of" in eval_response.scores
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_eval_run_benchmark_eval(self, eval_stack, inference_model):
|
||||
eval_impl, benchmarks_impl, datasets_impl = (
|
||||
eval_stack[Api.eval],
|
||||
eval_stack[Api.benchmarks],
|
||||
eval_stack[Api.datasets],
|
||||
)
|
||||
|
||||
response = await datasets_impl.list_datasets()
|
||||
assert len(response) > 0
|
||||
if response[0].provider_id != "huggingface":
|
||||
pytest.skip("Only huggingface provider supports pre-registered remote datasets")
|
||||
|
||||
await datasets_impl.register_dataset(
|
||||
dataset_id="mmlu",
|
||||
dataset_schema={
|
||||
"input_query": StringType(),
|
||||
"expected_answer": StringType(),
|
||||
"chat_completion_input": ChatCompletionInputType(),
|
||||
},
|
||||
url=URL(uri="https://huggingface.co/datasets/llamastack/evals"),
|
||||
metadata={
|
||||
"path": "llamastack/evals",
|
||||
"name": "evals__mmlu__details",
|
||||
"split": "train",
|
||||
},
|
||||
)
|
||||
|
||||
# register eval task
|
||||
await benchmarks_impl.register_benchmark(
|
||||
benchmark_id="meta-reference-mmlu",
|
||||
dataset_id="mmlu",
|
||||
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
|
||||
)
|
||||
|
||||
# list benchmarks
|
||||
response = await benchmarks_impl.list_benchmarks()
|
||||
assert len(response) > 0
|
||||
|
||||
benchmark_id = "meta-reference-mmlu"
|
||||
response = await eval_impl.run_eval(
|
||||
benchmark_id=benchmark_id,
|
||||
benchmark_config=dict(
|
||||
eval_candidate=ModelCandidate(
|
||||
model=inference_model,
|
||||
sampling_params=SamplingParams(),
|
||||
),
|
||||
num_examples=3,
|
||||
),
|
||||
)
|
||||
job_status = await eval_impl.job_status(benchmark_id, response.job_id)
|
||||
assert job_status and job_status.value == "completed"
|
||||
eval_response = await eval_impl.job_result(benchmark_id, response.job_id)
|
||||
assert eval_response is not None
|
||||
assert len(eval_response.generations) == 3
|
5
tests/integration/post_training/__init__.py
Normal file
5
tests/integration/post_training/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
101
tests/integration/post_training/test_post_training.py
Normal file
101
tests/integration/post_training/test_post_training.py
Normal file
|
@ -0,0 +1,101 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.job_types import JobStatus
|
||||
from llama_stack.apis.post_training import (
|
||||
Checkpoint,
|
||||
DataConfig,
|
||||
LoraFinetuningConfig,
|
||||
OptimizerConfig,
|
||||
PostTrainingJob,
|
||||
PostTrainingJobArtifactsResponse,
|
||||
PostTrainingJobStatusResponse,
|
||||
TrainingConfig,
|
||||
)
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest llama_stack/providers/tests/post_training/test_post_training.py
|
||||
# -m "torchtune_post_training_huggingface_datasetio"
|
||||
# -v -s --tb=short --disable-warnings
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="FIXME FIXME @yanxi0830 this needs to be migrated to use the API")
|
||||
class TestPostTraining:
|
||||
@pytest.mark.asyncio
|
||||
async def test_supervised_fine_tune(self, post_training_stack):
|
||||
algorithm_config = LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
lora_attn_modules=["q_proj", "v_proj", "output_proj"],
|
||||
apply_lora_to_mlp=True,
|
||||
apply_lora_to_output=False,
|
||||
rank=8,
|
||||
alpha=16,
|
||||
)
|
||||
|
||||
data_config = DataConfig(
|
||||
dataset_id="alpaca",
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type="adamw",
|
||||
lr=3e-4,
|
||||
lr_min=3e-5,
|
||||
weight_decay=0.1,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=1,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
max_steps_per_epoch=1,
|
||||
gradient_accumulation_steps=1,
|
||||
)
|
||||
post_training_impl = post_training_stack
|
||||
response = await post_training_impl.supervised_fine_tune(
|
||||
job_uuid="1234",
|
||||
model="Llama3.2-3B-Instruct",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=training_config,
|
||||
hyperparam_search_config={},
|
||||
logger_config={},
|
||||
checkpoint_dir="null",
|
||||
)
|
||||
assert isinstance(response, PostTrainingJob)
|
||||
assert response.job_uuid == "1234"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_jobs(self, post_training_stack):
|
||||
post_training_impl = post_training_stack
|
||||
jobs_list = await post_training_impl.get_training_jobs()
|
||||
assert isinstance(jobs_list, List)
|
||||
assert jobs_list[0].job_uuid == "1234"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_job_status(self, post_training_stack):
|
||||
post_training_impl = post_training_stack
|
||||
job_status = await post_training_impl.get_training_job_status("1234")
|
||||
assert isinstance(job_status, PostTrainingJobStatusResponse)
|
||||
assert job_status.job_uuid == "1234"
|
||||
assert job_status.status == JobStatus.completed
|
||||
assert isinstance(job_status.checkpoints[0], Checkpoint)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_job_artifacts(self, post_training_stack):
|
||||
post_training_impl = post_training_stack
|
||||
job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
|
||||
assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
|
||||
assert job_artifacts.job_uuid == "1234"
|
||||
assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
|
||||
assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"
|
||||
assert job_artifacts.checkpoints[0].epoch == 0
|
||||
assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path
|
|
@ -16,6 +16,7 @@ import pytest
|
|||
from pytest import CollectReport
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.env import get_env_or_fail
|
||||
from llama_stack.models.llama.datatypes import CoreModelId
|
||||
from llama_stack.models.llama.sku_list import (
|
||||
all_registered_models,
|
||||
|
@ -26,7 +27,6 @@ from llama_stack.models.llama.sku_list import (
|
|||
safety_models,
|
||||
)
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.tests.env import get_env_or_fail
|
||||
|
||||
from .metadata import API_MAPS
|
||||
|
||||
|
|
5
tests/integration/scoring/__init__.py
Normal file
5
tests/integration/scoring/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
160
tests/integration/scoring/test_scoring.py
Normal file
160
tests/integration/scoring/test_scoring.py
Normal file
|
@ -0,0 +1,160 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from ..datasetio.test_datasetio import register_dataset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_judge_prompt_template():
|
||||
return "Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9."
|
||||
|
||||
|
||||
def test_scoring_functions_list(llama_stack_client):
|
||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||
# but so far we don't have an unregister API unfortunately, so be careful
|
||||
response = llama_stack_client.scoring_functions.list()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) > 0
|
||||
|
||||
|
||||
def test_scoring_score(llama_stack_client):
|
||||
register_dataset(llama_stack_client, for_rag=True)
|
||||
response = llama_stack_client.datasets.list()
|
||||
assert len(response) == 1
|
||||
|
||||
# scoring individual rows
|
||||
rows = llama_stack_client.datasetio.get_rows_paginated(
|
||||
dataset_id="test_dataset",
|
||||
rows_in_page=3,
|
||||
)
|
||||
assert len(rows.rows) == 3
|
||||
|
||||
scoring_fns_list = llama_stack_client.scoring_functions.list()
|
||||
scoring_functions = {
|
||||
scoring_fns_list[0].identifier: None,
|
||||
}
|
||||
|
||||
response = llama_stack_client.scoring.score(
|
||||
input_rows=rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
assert x in response.results
|
||||
assert len(response.results[x].score_rows) == len(rows.rows)
|
||||
|
||||
# score batch
|
||||
response = llama_stack_client.scoring.score_batch(
|
||||
dataset_id="test_dataset",
|
||||
scoring_functions=scoring_functions,
|
||||
save_results_dataset=False,
|
||||
)
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
assert x in response.results
|
||||
assert len(response.results[x].score_rows) == 5
|
||||
|
||||
|
||||
def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge_prompt_template, judge_model_id):
|
||||
register_dataset(llama_stack_client, for_rag=True)
|
||||
response = llama_stack_client.datasets.list()
|
||||
assert len(response) == 1
|
||||
|
||||
# scoring individual rows
|
||||
rows = llama_stack_client.datasetio.get_rows_paginated(
|
||||
dataset_id="test_dataset",
|
||||
rows_in_page=3,
|
||||
)
|
||||
assert len(rows.rows) == 3
|
||||
|
||||
scoring_functions = {
|
||||
"llm-as-judge::base": dict(
|
||||
type="llm_as_judge",
|
||||
judge_model=judge_model_id,
|
||||
prompt_template=sample_judge_prompt_template,
|
||||
judge_score_regexes=[r"Score: (\d+)"],
|
||||
aggregation_functions=[
|
||||
"categorical_count",
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
response = llama_stack_client.scoring.score(
|
||||
input_rows=rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
assert x in response.results
|
||||
assert len(response.results[x].score_rows) == len(rows.rows)
|
||||
|
||||
# score batch
|
||||
response = llama_stack_client.scoring.score_batch(
|
||||
dataset_id="test_dataset",
|
||||
scoring_functions=scoring_functions,
|
||||
save_results_dataset=False,
|
||||
)
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
assert x in response.results
|
||||
assert len(response.results[x].score_rows) == 5
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping because this seems to be really slow")
|
||||
def test_scoring_score_with_aggregation_functions(llama_stack_client, sample_judge_prompt_template, judge_model_id):
|
||||
register_dataset(llama_stack_client, for_rag=True)
|
||||
rows = llama_stack_client.datasetio.get_rows_paginated(
|
||||
dataset_id="test_dataset",
|
||||
rows_in_page=3,
|
||||
)
|
||||
assert len(rows.rows) == 3
|
||||
|
||||
scoring_fns_list = llama_stack_client.scoring_functions.list()
|
||||
scoring_functions = {}
|
||||
aggr_fns = [
|
||||
"accuracy",
|
||||
"median",
|
||||
"categorical_count",
|
||||
"average",
|
||||
]
|
||||
for x in scoring_fns_list:
|
||||
if x.provider_id == "llm-as-judge":
|
||||
aggr_fns = ["categorical_count"]
|
||||
scoring_functions[x.identifier] = dict(
|
||||
type="llm_as_judge",
|
||||
judge_model=judge_model_id,
|
||||
prompt_template=sample_judge_prompt_template,
|
||||
judge_score_regexes=[r"Score: (\d+)"],
|
||||
aggregation_functions=aggr_fns,
|
||||
)
|
||||
elif x.provider_id == "basic" or x.provider_id == "braintrust":
|
||||
if "regex_parser" in x.identifier:
|
||||
scoring_functions[x.identifier] = dict(
|
||||
type="regex_parser",
|
||||
parsing_regexes=[r"Score: (\d+)"],
|
||||
aggregation_functions=aggr_fns,
|
||||
)
|
||||
else:
|
||||
scoring_functions[x.identifier] = dict(
|
||||
type="basic",
|
||||
aggregation_functions=aggr_fns,
|
||||
)
|
||||
else:
|
||||
scoring_functions[x.identifier] = None
|
||||
|
||||
response = llama_stack_client.scoring.score(
|
||||
input_rows=rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
)
|
||||
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
assert x in response.results
|
||||
assert len(response.results[x].score_rows) == len(rows.rows)
|
||||
assert len(response.results[x].aggregated_results) == len(aggr_fns)
|
66
tests/integration/tool_runtime/test_builtin_tools.py
Normal file
66
tests/integration/tool_runtime/test_builtin_tools.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_search_query():
|
||||
return "What are the latest developments in quantum computing?"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_wolfram_alpha_query():
|
||||
return "What is the square root of 16?"
|
||||
|
||||
|
||||
def test_web_search_tool(llama_stack_client, sample_search_query):
|
||||
"""Test the web search tool functionality."""
|
||||
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
||||
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
||||
|
||||
response = llama_stack_client.tool_runtime.invoke_tool(
|
||||
tool_name="web_search", kwargs={"query": sample_search_query}
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert response.content is not None
|
||||
assert len(response.content) > 0
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
content = json.loads(response.content)
|
||||
assert "query" in content
|
||||
assert "top_k" in content
|
||||
assert len(content["top_k"]) > 0
|
||||
|
||||
first = content["top_k"][0]
|
||||
assert "title" in first
|
||||
assert "url" in first
|
||||
|
||||
|
||||
def test_wolfram_alpha_tool(llama_stack_client, sample_wolfram_alpha_query):
|
||||
"""Test the wolfram alpha tool functionality."""
|
||||
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
|
||||
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
|
||||
|
||||
response = llama_stack_client.tool_runtime.invoke_tool(
|
||||
tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
|
||||
)
|
||||
|
||||
print(response.content)
|
||||
assert response.content is not None
|
||||
assert len(response.content) > 0
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
content = json.loads(response.content)
|
||||
result = content["queryresult"]
|
||||
assert "success" in result
|
||||
assert result["success"]
|
||||
assert "pods" in result
|
||||
assert len(result["pods"]) > 0
|
|
@ -4,29 +4,23 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
from llama_stack_client.types import Document
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def empty_vector_db_registry(llama_stack_client):
|
||||
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
||||
for vector_db_id in vector_dbs:
|
||||
llama_stack_client.vector_dbs.unregister(vector_db_id=vector_db_id)
|
||||
def client_with_empty_registry(client_with_models):
|
||||
def clear_registry():
|
||||
vector_dbs = [vector_db.identifier for vector_db in client_with_models.vector_dbs.list()]
|
||||
for vector_db_id in vector_dbs:
|
||||
client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id)
|
||||
|
||||
clear_registry()
|
||||
yield client_with_models
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry):
|
||||
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
|
||||
llama_stack_client.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
)
|
||||
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
||||
return vector_dbs
|
||||
# you must clean after the last test if you were running tests against
|
||||
# a stateful server instance
|
||||
clear_registry()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
@ -63,9 +57,15 @@ def assert_valid_response(response):
|
|||
assert isinstance(chunk.content, str)
|
||||
|
||||
|
||||
def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vector_db_registry, sample_documents):
|
||||
vector_db_id = single_entry_vector_db_registry[0]
|
||||
llama_stack_client.tool_runtime.rag_tool.insert(
|
||||
def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
|
||||
vector_db_id = "test_vector_db"
|
||||
client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=384,
|
||||
)
|
||||
|
||||
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||
documents=sample_documents,
|
||||
chunk_size_in_tokens=512,
|
||||
vector_db_id=vector_db_id,
|
||||
|
@ -73,7 +73,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
|
|||
|
||||
# Query with a direct match
|
||||
query1 = "programming language"
|
||||
response1 = llama_stack_client.vector_io.query(
|
||||
response1 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query=query1,
|
||||
)
|
||||
|
@ -82,7 +82,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
|
|||
|
||||
# Query with semantic similarity
|
||||
query2 = "AI and brain-inspired computing"
|
||||
response2 = llama_stack_client.vector_io.query(
|
||||
response2 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query=query2,
|
||||
)
|
||||
|
@ -91,7 +91,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
|
|||
|
||||
# Query with limit on number of results (max_chunks=2)
|
||||
query3 = "computer"
|
||||
response3 = llama_stack_client.vector_io.query(
|
||||
response3 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query=query3,
|
||||
params={"max_chunks": 2},
|
||||
|
@ -101,7 +101,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
|
|||
|
||||
# Query with threshold on similarity score
|
||||
query4 = "computer"
|
||||
response4 = llama_stack_client.vector_io.query(
|
||||
response4 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query=query4,
|
||||
params={"score_threshold": 0.01},
|
||||
|
@ -110,20 +110,20 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
|
|||
assert all(score >= 0.01 for score in response4.scores)
|
||||
|
||||
|
||||
def test_vector_db_insert_from_url_and_query(llama_stack_client, empty_vector_db_registry):
|
||||
providers = [p for p in llama_stack_client.providers.list() if p.api == "vector_io"]
|
||||
def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
|
||||
providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
|
||||
assert len(providers) > 0
|
||||
|
||||
vector_db_id = "test_vector_db"
|
||||
|
||||
llama_stack_client.vector_dbs.register(
|
||||
client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=384,
|
||||
)
|
||||
|
||||
# list to check memory bank is successfully registered
|
||||
available_vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
||||
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||
assert vector_db_id in available_vector_dbs
|
||||
|
||||
# URLs of documents to insert
|
||||
|
@ -144,14 +144,14 @@ def test_vector_db_insert_from_url_and_query(llama_stack_client, empty_vector_db
|
|||
for i, url in enumerate(urls)
|
||||
]
|
||||
|
||||
llama_stack_client.tool_runtime.rag_tool.insert(
|
||||
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||
documents=documents,
|
||||
vector_db_id=vector_db_id,
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
|
||||
# Query for the name of method
|
||||
response1 = llama_stack_client.vector_io.query(
|
||||
response1 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query="What's the name of the fine-tunning method used?",
|
||||
)
|
||||
|
@ -159,7 +159,7 @@ def test_vector_db_insert_from_url_and_query(llama_stack_client, empty_vector_db
|
|||
assert any("lora" in chunk.content.lower() for chunk in response1.chunks)
|
||||
|
||||
# Query for the name of model
|
||||
response2 = llama_stack_client.vector_io.query(
|
||||
response2 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query="Which Llama model is mentioned?",
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue