mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
refactor(test): move tools, evals, datasetio, scoring and post training tests (#1401)
All of the tests from `llama_stack/providers/tests/` are now moved to `tests/integration`. I converted the `tools`, `scoring` and `datasetio` tests to use API. However, `eval` and `post_training` proved to be a bit challenging to leaving those. I think `post_training` should be relatively straightforward also. As part of this, I noticed that `wolfram_alpha` tool wasn't added to some of our commonly used distros so I added it. I am going to remove a lot of code duplication from distros next so while this looks like a one-off right now, it will go away and be there uniformly for all distros.
This commit is contained in:
parent
dd0db8038b
commit
abfbaf3c1b
51 changed files with 471 additions and 1245 deletions
|
@ -4,29 +4,23 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
from llama_stack_client.types import Document
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def empty_vector_db_registry(llama_stack_client):
|
||||
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
||||
for vector_db_id in vector_dbs:
|
||||
llama_stack_client.vector_dbs.unregister(vector_db_id=vector_db_id)
|
||||
def client_with_empty_registry(client_with_models):
|
||||
def clear_registry():
|
||||
vector_dbs = [vector_db.identifier for vector_db in client_with_models.vector_dbs.list()]
|
||||
for vector_db_id in vector_dbs:
|
||||
client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id)
|
||||
|
||||
clear_registry()
|
||||
yield client_with_models
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry):
|
||||
vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}"
|
||||
llama_stack_client.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_dimension=384,
|
||||
)
|
||||
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
||||
return vector_dbs
|
||||
# you must clean after the last test if you were running tests against
|
||||
# a stateful server instance
|
||||
clear_registry()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
@ -63,9 +57,15 @@ def assert_valid_response(response):
|
|||
assert isinstance(chunk.content, str)
|
||||
|
||||
|
||||
def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vector_db_registry, sample_documents):
|
||||
vector_db_id = single_entry_vector_db_registry[0]
|
||||
llama_stack_client.tool_runtime.rag_tool.insert(
|
||||
def test_vector_db_insert_inline_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
|
||||
vector_db_id = "test_vector_db"
|
||||
client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=384,
|
||||
)
|
||||
|
||||
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||
documents=sample_documents,
|
||||
chunk_size_in_tokens=512,
|
||||
vector_db_id=vector_db_id,
|
||||
|
@ -73,7 +73,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
|
|||
|
||||
# Query with a direct match
|
||||
query1 = "programming language"
|
||||
response1 = llama_stack_client.vector_io.query(
|
||||
response1 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query=query1,
|
||||
)
|
||||
|
@ -82,7 +82,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
|
|||
|
||||
# Query with semantic similarity
|
||||
query2 = "AI and brain-inspired computing"
|
||||
response2 = llama_stack_client.vector_io.query(
|
||||
response2 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query=query2,
|
||||
)
|
||||
|
@ -91,7 +91,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
|
|||
|
||||
# Query with limit on number of results (max_chunks=2)
|
||||
query3 = "computer"
|
||||
response3 = llama_stack_client.vector_io.query(
|
||||
response3 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query=query3,
|
||||
params={"max_chunks": 2},
|
||||
|
@ -101,7 +101,7 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
|
|||
|
||||
# Query with threshold on similarity score
|
||||
query4 = "computer"
|
||||
response4 = llama_stack_client.vector_io.query(
|
||||
response4 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query=query4,
|
||||
params={"score_threshold": 0.01},
|
||||
|
@ -110,20 +110,20 @@ def test_vector_db_insert_inline_and_query(llama_stack_client, single_entry_vect
|
|||
assert all(score >= 0.01 for score in response4.scores)
|
||||
|
||||
|
||||
def test_vector_db_insert_from_url_and_query(llama_stack_client, empty_vector_db_registry):
|
||||
providers = [p for p in llama_stack_client.providers.list() if p.api == "vector_io"]
|
||||
def test_vector_db_insert_from_url_and_query(client_with_empty_registry, sample_documents, embedding_model_id):
|
||||
providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
|
||||
assert len(providers) > 0
|
||||
|
||||
vector_db_id = "test_vector_db"
|
||||
|
||||
llama_stack_client.vector_dbs.register(
|
||||
client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=384,
|
||||
)
|
||||
|
||||
# list to check memory bank is successfully registered
|
||||
available_vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
|
||||
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||
assert vector_db_id in available_vector_dbs
|
||||
|
||||
# URLs of documents to insert
|
||||
|
@ -144,14 +144,14 @@ def test_vector_db_insert_from_url_and_query(llama_stack_client, empty_vector_db
|
|||
for i, url in enumerate(urls)
|
||||
]
|
||||
|
||||
llama_stack_client.tool_runtime.rag_tool.insert(
|
||||
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||
documents=documents,
|
||||
vector_db_id=vector_db_id,
|
||||
chunk_size_in_tokens=512,
|
||||
)
|
||||
|
||||
# Query for the name of method
|
||||
response1 = llama_stack_client.vector_io.query(
|
||||
response1 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query="What's the name of the fine-tunning method used?",
|
||||
)
|
||||
|
@ -159,7 +159,7 @@ def test_vector_db_insert_from_url_and_query(llama_stack_client, empty_vector_db
|
|||
assert any("lora" in chunk.content.lower() for chunk in response1.chunks)
|
||||
|
||||
# Query for the name of model
|
||||
response2 = llama_stack_client.vector_io.query(
|
||||
response2 = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=vector_db_id,
|
||||
query="Which Llama model is mentioned?",
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue