Merge branch 'main' of https://github.com/meta-llama/llama-stack into add_nemo_customizer

This commit is contained in:
Ubuntu 2025-03-20 09:34:19 +00:00
commit f534b4c2ea
571 changed files with 229651 additions and 12956 deletions

View file

@ -0,0 +1,234 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import logging
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any, Dict
from unittest.mock import AsyncMock, patch
import pytest
import pytest_asyncio
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from openai.types.chat.chat_completion_chunk import (
Choice as OpenAIChoice,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDelta as OpenAIChoiceDelta,
)
from openai.types.model import Model as OpenAIModel
from llama_stack.apis.inference import ToolChoice, ToolConfig
from llama_stack.apis.models import Model
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import (
VLLMInferenceAdapter,
_process_vllm_chat_completion_stream_response,
)
# These are unit test for the remote vllm provider
# implementation. This should only contain tests which are specific to
# the implementation details of those classes. More general
# (API-level) tests should be placed in tests/integration/inference/
#
# How to run this test:
#
# pytest tests/unit/providers/inference/test_remote_vllm.py \
# -v -s --tb=short --disable-warnings
class MockInferenceAdapterWithSleep:
def __init__(self, sleep_time: int, response: Dict[str, Any]):
self.httpd = None
class DelayedRequestHandler(BaseHTTPRequestHandler):
# ruff: noqa: N802
def do_POST(self):
time.sleep(sleep_time)
self.send_response(code=200)
self.end_headers()
self.wfile.write(json.dumps(response).encode("utf-8"))
self.request_handler = DelayedRequestHandler
def __enter__(self):
httpd = HTTPServer(("", 0), self.request_handler)
self.httpd = httpd
host, port = httpd.server_address
httpd_thread = threading.Thread(target=httpd.serve_forever)
httpd_thread.daemon = True # stop server if this thread terminates
httpd_thread.start()
config = VLLMInferenceAdapterConfig(url=f"http://{host}:{port}")
inference_adapter = VLLMInferenceAdapter(config)
return inference_adapter
def __exit__(self, _exc_type, _exc_value, _traceback):
if self.httpd:
self.httpd.shutdown()
self.httpd.server_close()
@pytest.fixture(scope="module")
def mock_openai_models_list():
with patch("openai.resources.models.AsyncModels.list", new_callable=AsyncMock) as mock_list:
yield mock_list
@pytest_asyncio.fixture(scope="module")
async def vllm_inference_adapter():
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
inference_adapter = VLLMInferenceAdapter(config)
inference_adapter.model_store = AsyncMock()
await inference_adapter.initialize()
return inference_adapter
@pytest.mark.asyncio
async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
async def mock_openai_models():
yield OpenAIModel(id="foo", created=1, object="model", owned_by="test")
mock_openai_models_list.return_value = mock_openai_models()
foo_model = Model(identifier="foo", provider_resource_id="foo", provider_id="vllm-inference")
await vllm_inference_adapter.register_model(foo_model)
mock_openai_models_list.assert_called()
@pytest.mark.asyncio
async def test_old_vllm_tool_choice(vllm_inference_adapter):
"""
Test that we set tool_choice to none when no tools are in use
to support older versions of vLLM
"""
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
with patch.object(vllm_inference_adapter, "_nonstream_chat_completion") as mock_nonstream_completion:
# No tools but auto tool choice
await vllm_inference_adapter.chat_completion(
"mock-model",
[],
stream=False,
tools=None,
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
mock_nonstream_completion.assert_called()
request = mock_nonstream_completion.call_args.args[0]
# Ensure tool_choice gets converted to none for older vLLM versions
assert request.tool_config.tool_choice == ToolChoice.none
@pytest.mark.asyncio
async def test_tool_call_delta_empty_tool_call_buf():
"""
Test that we don't generate extra chunks when processing a
tool call response that didn't call any tools. Previously we would
emit chunks with spurious ToolCallParseStatus.succeeded or
ToolCallParseStatus.failed when processing chunks that didn't
actually make any tool calls.
"""
async def mock_stream():
delta = OpenAIChoiceDelta(content="", tool_calls=None)
choices = [OpenAIChoice(delta=delta, finish_reason="stop", index=0)]
mock_chunk = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=choices,
)
for chunk in [mock_chunk]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 1
assert chunks[0].event.stop_reason == StopReason.end_of_turn
@pytest.mark.asyncio
async def test_process_vllm_chat_completion_stream_response_no_choices():
"""
Test that we don't error out when vLLM returns no choices for a
completion request. This can happen when there's an error thrown
in vLLM for example.
"""
async def mock_stream():
choices = []
mock_chunk = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=choices,
)
for chunk in [mock_chunk]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 0
def test_chat_completion_doesnt_block_event_loop(caplog):
loop = asyncio.new_event_loop()
loop.set_debug(True)
caplog.set_level(logging.WARNING)
# Log when event loop is blocked for more than 200ms
loop.slow_callback_duration = 0.5
# Sleep for 500ms in our delayed http response
sleep_time = 0.5
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
mock_response = {
"id": "chatcmpl-abc123",
"object": "chat.completion",
"created": 1,
"modle": "mock-model",
"choices": [
{
"message": {"content": ""},
"logprobs": None,
"finish_reason": "stop",
"index": 0,
}
],
}
async def do_chat_completion():
await inference_adapter.chat_completion(
"mock-model",
[],
stream=False,
tools=None,
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
with MockInferenceAdapterWithSleep(sleep_time, mock_response) as inference_adapter:
inference_adapter.model_store = AsyncMock()
inference_adapter.model_store.get_model.return_value = mock_model
loop.run_until_complete(inference_adapter.initialize())
# Clear the logs so far and run the actual chat completion we care about
caplog.clear()
loop.run_until_complete(do_chat_completion())
# Ensure we don't have any asyncio warnings in the captured log
# records from our chat completion call. A message gets logged
# here any time we exceed the slow_callback_duration configured
# above.
asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
assert not asyncio_warnings

View file

@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from pydantic import BaseModel
from llama_stack.distribution.distribution import get_provider_registry, providable_apis
from llama_stack.distribution.utils.dynamic import instantiate_class_type
class TestProviderConfigurations:
"""Test suite for testing provider configurations across all API types."""
def test_all_api_providers_exist(self):
provider_registry = get_provider_registry()
for api in providable_apis():
providers = provider_registry.get(api, {})
assert providers, f"No providers found for API type: {api}"
@pytest.mark.parametrize("api", providable_apis())
def test_api_providers(self, api):
provider_registry = get_provider_registry()
providers = provider_registry.get(api, {})
assert providers, f"No providers found for API type: {api}"
failures = []
for provider_type, provider_spec in providers.items():
try:
self._verify_provider_config(provider_type, provider_spec)
except Exception as e:
failures.append(f"Failed to verify {provider_type} config: {str(e)}")
if failures:
pytest.fail("\n".join(failures))
def _verify_provider_config(self, provider_type, provider_spec):
"""Helper method to verify a single provider configuration."""
# Get the config class
config_class_name = provider_spec.config_class
config_type = instantiate_class_type(config_class_name)
assert issubclass(config_type, BaseModel), f"{config_class_name} is not a subclass of BaseModel"
assert hasattr(config_type, "sample_run_config"), f"{config_class_name} does not have sample_run_config method"
sample_config = config_type.sample_run_config(__distro_dir__="foobarbaz")
assert isinstance(sample_config, dict), f"{config_class_name}.sample_run_config() did not return a dict"

View file

@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import random
import numpy as np
import pytest
from llama_stack.apis.vector_io import Chunk
EMBEDDING_DIMENSION = 384
@pytest.fixture
def vector_db_id() -> str:
return f"test-vector-db-{random.randint(1, 100)}"
@pytest.fixture(scope="session")
def embedding_dimension() -> int:
return EMBEDDING_DIMENSION
@pytest.fixture(scope="session")
def sample_chunks():
"""Generates chunks that force multiple batches for a single document to expose ID conflicts."""
n, k = 10, 3
sample = [
Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
for j in range(k)
for i in range(n)
]
return sample
@pytest.fixture(scope="session")
def sample_embeddings(sample_chunks):
np.random.seed(42)
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])

View file

@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import os
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from llama_stack.apis.inference import EmbeddingsResponse, Inference
from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorDB,
VectorDBStore,
)
from llama_stack.providers.inline.vector_io.qdrant.config import (
QdrantVectorIOConfig as InlineQdrantVectorIOConfig,
)
from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
QdrantVectorIOAdapter,
)
# This test is a unit test for the QdrantVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_qdrant.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
@pytest.fixture
def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig:
return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db"))
@pytest.fixture(scope="session")
def loop():
return asyncio.new_event_loop()
@pytest.fixture
def mock_vector_db(vector_db_id) -> MagicMock:
mock_vector_db = MagicMock(spec=VectorDB)
mock_vector_db.embedding_model = "embedding_model"
mock_vector_db.identifier = vector_db_id
return mock_vector_db
@pytest.fixture
def mock_vector_db_store(mock_vector_db) -> MagicMock:
mock_store = MagicMock(spec=VectorDBStore)
mock_store.get_vector_db = AsyncMock(return_value=mock_vector_db)
return mock_store
@pytest.fixture
def mock_api_service(sample_embeddings):
mock_api_service = MagicMock(spec=Inference)
mock_api_service.embeddings = AsyncMock(return_value=EmbeddingsResponse(embeddings=sample_embeddings))
return mock_api_service
@pytest_asyncio.fixture
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service)
adapter.vector_db_store = mock_vector_db_store
await adapter.initialize()
yield adapter
await adapter.shutdown()
__QUERY = "Sample query"
@pytest.mark.asyncio
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 30)])
async def test_qdrant_adapter_returns_expected_chunks(
qdrant_adapter: QdrantVectorIOAdapter,
vector_db_id,
sample_chunks,
sample_embeddings,
max_query_chunks,
expected_chunks,
) -> None:
assert qdrant_adapter is not None
await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
index = await qdrant_adapter._get_and_cache_vector_db_index(vector_db_id=vector_db_id)
assert index is not None
response = await qdrant_adapter.query_chunks(
query=__QUERY,
vector_db_id=vector_db_id,
params={"max_chunks": max_query_chunks},
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == expected_chunks
# To by-pass attempt to convert a Mock to JSON
def _prepare_for_json(value: Any) -> str:
return str(value)
@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
@pytest.mark.asyncio
async def test_qdrant_register_and_unregister_vector_db(
qdrant_adapter: QdrantVectorIOAdapter,
mock_vector_db,
sample_chunks,
) -> None:
# Initially, no collections
vector_db_id = mock_vector_db.identifier
assert len((await qdrant_adapter.client.get_collections()).collections) == 0
# Register does not create a collection
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
await qdrant_adapter.register_vector_db(mock_vector_db)
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
# First insert creates the collection
await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
assert await qdrant_adapter.client.collection_exists(vector_db_id)
# Unregister deletes the collection
await qdrant_adapter.unregister_vector_db(vector_db_id)
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
assert len((await qdrant_adapter.client.get_collections()).collections) == 0

View file

@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import sqlite3
import numpy as np
import pytest
import pytest_asyncio
import sqlite_vec
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
SQLiteVecIndex,
SQLiteVecVectorIOAdapter,
generate_chunk_id,
)
# This test is a unit test for the SQLiteVecVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_sqlite_vec.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
SQLITE_VEC_PROVIDER = "sqlite_vec"
@pytest.fixture(scope="session")
def loop():
return asyncio.new_event_loop()
@pytest.fixture(scope="session", autouse=True)
def sqlite_connection(loop):
conn = sqlite3.connect(":memory:")
try:
conn.enable_load_extension(True)
sqlite_vec.load(conn)
yield conn
finally:
conn.close()
@pytest_asyncio.fixture(scope="session", autouse=True)
async def sqlite_vec_index(sqlite_connection, embedding_dimension):
return await SQLiteVecIndex.create(dimension=embedding_dimension, connection=sqlite_connection, bank_id="test_bank")
@pytest.mark.asyncio
async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2)
cur = sqlite_vec_index.connection.cursor()
cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
count = cur.fetchone()[0]
assert count == len(sample_chunks)
@pytest.mark.asyncio
async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
@pytest.mark.asyncio
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
"""Test that chunk IDs do not conflict across batches when inserting chunks."""
# Reduce batch size to force multiple batches for same document
# since there are 10 chunks per document and batch size is 2
batch_size = 2
sample_embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size)
cur = sqlite_vec_index.connection.cursor()
# Retrieve all chunk IDs to check for duplicates
cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
chunk_ids = [row[0] for row in cur.fetchall()]
cur.close()
# Ensure all chunk IDs are unique
assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
@pytest.fixture(scope="session")
async def sqlite_vec_adapter(sqlite_connection):
config = type("Config", (object,), {"db_path": ":memory:"}) # Mock config with in-memory database
adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None)
await adapter.initialize()
yield adapter
await adapter.shutdown()
def test_generate_chunk_id():
chunks = [
Chunk(content="test", metadata={"document_id": "doc-1"}),
Chunk(content="test ", metadata={"document_id": "doc-1"}),
Chunk(content="test 3", metadata={"document_id": "doc-1"}),
]
chunk_ids = sorted([generate_chunk_id(chunk.metadata["document_id"], chunk.content) for chunk in chunks])
assert chunk_ids == [
"177a1368-f6a8-0c50-6e92-18677f2c3de3",
"bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
"f68df25d-d9aa-ab4d-5684-64a233add20d",
]