Make embedding generation go through inference (#606)

This PR does the following:
1) adds the ability to generate embeddings in all supported inference
providers.
2) Moves all the memory providers to use the inference API and improved
the memory tests to setup the inference stack correctly and use the
embedding models

This is a merge from #589 and #598
This commit is contained in:
Dinesh Yeduguru 2024-12-12 11:47:50 -08:00 committed by GitHub
parent a14785af46
commit 96e158eaac
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 677 additions and 156 deletions

View file

@ -18,6 +18,12 @@ def pytest_addoption(parser):
default=None,
help="Specify the inference model to use for testing",
)
parser.addoption(
"--embedding-model",
action="store",
default=None,
help="Specify the embedding model to use for testing",
)
def pytest_configure(config):

View file

@ -9,9 +9,9 @@ import os
import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceInferenceConfig,
)
@ -47,6 +47,9 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
inference_model = (
[inference_model] if isinstance(inference_model, str) else inference_model
)
# If embedding dimension is set, use the 8B model for testing
if os.getenv("EMBEDDING_DIMENSION"):
inference_model = ["meta-llama/Llama-3.1-8B-Instruct"]
return ProviderFixture(
providers=[
@ -85,7 +88,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
inference_model = (
[inference_model] if isinstance(inference_model, str) else inference_model
)
if "Llama3.1-8B-Instruct" in inference_model:
if inference_model and "Llama3.1-8B-Instruct" in inference_model:
pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
return ProviderFixture(
@ -232,11 +235,23 @@ INFERENCE_FIXTURES = [
async def inference_stack(request, inference_model):
fixture_name = request.param
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
model_type = ModelType.llm
metadata = {}
if os.getenv("EMBEDDING_DIMENSION"):
model_type = ModelType.embedding_model
metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")
test_stack = await construct_stack_for_test(
[Api.inference],
{"inference": inference_fixture.providers},
inference_fixture.provider_data,
models=[ModelInput(model_id=inference_model)],
models=[
ModelInput(
model_id=inference_model,
model_type=model_type,
metadata=metadata,
)
],
)
return test_stack.impls[Api.inference], test_stack.impls[Api.models]

View file

@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.inference import EmbeddingsResponse, ModelType
# How to run this test:
# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py
class TestEmbeddings:
@pytest.mark.asyncio
async def test_embeddings(self, inference_model, inference_stack):
inference_impl, models_impl = inference_stack
model = await models_impl.get_model(inference_model)
if model.model_type != ModelType.embedding_model:
pytest.skip("This test is only applicable for embedding models")
response = await inference_impl.embeddings(
model_id=inference_model,
contents=["Hello, world!"],
)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) > 0
assert all(isinstance(embedding, list) for embedding in response.embeddings)
assert all(
isinstance(value, float)
for embedding in response.embeddings
for value in embedding
)
@pytest.mark.asyncio
async def test_batch_embeddings(self, inference_model, inference_stack):
inference_impl, models_impl = inference_stack
model = await models_impl.get_model(inference_model)
if model.model_type != ModelType.embedding_model:
pytest.skip("This test is only applicable for embedding models")
texts = ["Hello, world!", "This is a test", "Testing embeddings"]
response = await inference_impl.embeddings(
model_id=inference_model,
contents=texts,
)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == len(texts)
assert all(isinstance(embedding, list) for embedding in response.embeddings)
assert all(
isinstance(value, float)
for embedding in response.embeddings
for value in embedding
)
embedding_dim = len(response.embeddings[0])
assert all(len(embedding) == embedding_dim for embedding in response.embeddings)

View file

@ -6,9 +6,65 @@
import pytest
from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES
from .fixtures import MEMORY_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"inference": "meta_reference",
"memory": "faiss",
},
id="meta_reference",
marks=pytest.mark.meta_reference,
),
pytest.param(
{
"inference": "ollama",
"memory": "pgvector",
},
id="ollama",
marks=pytest.mark.ollama,
),
pytest.param(
{
"inference": "together",
"memory": "chroma",
},
id="chroma",
marks=pytest.mark.chroma,
),
pytest.param(
{
"inference": "bedrock",
"memory": "qdrant",
},
id="qdrant",
marks=pytest.mark.qdrant,
),
pytest.param(
{
"inference": "fireworks",
"memory": "weaviate",
},
id="weaviate",
marks=pytest.mark.weaviate,
),
]
def pytest_addoption(parser):
parser.addoption(
"--inference-model",
action="store",
default=None,
help="Specify the inference model to use for testing",
)
def pytest_configure(config):
for fixture_name in MEMORY_FIXTURES:
config.addinivalue_line(
@ -18,12 +74,22 @@ def pytest_configure(config):
def pytest_generate_tests(metafunc):
if "inference_model" in metafunc.fixturenames:
model = metafunc.config.getoption("--inference-model")
if not model:
raise ValueError(
"No inference model specified. Please provide a valid inference model."
)
params = [pytest.param(model, id="")]
metafunc.parametrize("inference_model", params, indirect=True)
if "memory_stack" in metafunc.fixturenames:
metafunc.parametrize(
"memory_stack",
[
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
for fixture_name in MEMORY_FIXTURES
],
indirect=True,
available_fixtures = {
"inference": INFERENCE_FIXTURES,
"memory": MEMORY_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("memory_stack", combinations, indirect=True)

View file

@ -10,6 +10,8 @@ import tempfile
import pytest
import pytest_asyncio
from llama_stack.apis.inference import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
@ -105,14 +107,30 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
@pytest_asyncio.fixture(scope="session")
async def memory_stack(request):
fixture_name = request.param
fixture = request.getfixturevalue(f"memory_{fixture_name}")
async def memory_stack(inference_model, request):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["inference", "memory"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
test_stack = await construct_stack_for_test(
[Api.memory],
{"memory": fixture.providers},
fixture.provider_data,
[Api.memory, Api.inference],
providers,
provider_data,
models=[
ModelInput(
model_id=inference_model,
model_type=ModelType.embedding_model,
metadata={
"embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
},
)
],
)
return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]

View file

@ -45,12 +45,14 @@ def sample_documents():
]
async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
async def register_memory_bank(
banks_impl: MemoryBanks, inference_model: str
) -> MemoryBank:
bank_id = f"test_bank_{uuid.uuid4().hex}"
return await banks_impl.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
embedding_model=inference_model,
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
@ -59,11 +61,11 @@ async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
class TestMemory:
@pytest.mark.asyncio
async def test_banks_list(self, memory_stack):
async def test_banks_list(self, memory_stack, inference_model):
_, banks_impl = memory_stack
# Register a test bank
registered_bank = await register_memory_bank(banks_impl)
registered_bank = await register_memory_bank(banks_impl, inference_model)
try:
# Verify our bank shows up in list
@ -84,7 +86,7 @@ class TestMemory:
)
@pytest.mark.asyncio
async def test_banks_register(self, memory_stack):
async def test_banks_register(self, memory_stack, inference_model):
_, banks_impl = memory_stack
bank_id = f"test_bank_{uuid.uuid4().hex}"
@ -94,7 +96,7 @@ class TestMemory:
await banks_impl.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
embedding_model=inference_model,
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
@ -109,7 +111,7 @@ class TestMemory:
await banks_impl.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
embedding_model=inference_model,
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
@ -126,13 +128,15 @@ class TestMemory:
await banks_impl.unregister_memory_bank(bank_id)
@pytest.mark.asyncio
async def test_query_documents(self, memory_stack, sample_documents):
async def test_query_documents(
self, memory_stack, inference_model, sample_documents
):
memory_impl, banks_impl = memory_stack
with pytest.raises(ValueError):
await memory_impl.insert_documents("test_bank", sample_documents)
registered_bank = await register_memory_bank(banks_impl)
registered_bank = await register_memory_bank(banks_impl, inference_model)
await memory_impl.insert_documents(
registered_bank.memory_bank_id, sample_documents
)
@ -165,13 +169,13 @@ class TestMemory:
# Test case 5: Query with threshold on similarity score
query5 = "quantum computing" # Not directly related to any document
params5 = {"score_threshold": 0.2}
params5 = {"score_threshold": 0.01}
response5 = await memory_impl.query_documents(
registered_bank.memory_bank_id, query5, params5
)
assert_valid_response(response5)
print("The scores are:", response5.scores)
assert all(score >= 0.2 for score in response5.scores)
assert all(score >= 0.01 for score in response5.scores)
def assert_valid_response(response: QueryDocumentsResponse):