Memory tests pass now

Author: Ashwin Bharambe
Date: 2024-12-15 20:55:06 -08:00
Parent: e51154964f
Commit: 59ce047aea
23 changed files with 122 additions and 81 deletions

View file

@ -17,6 +17,9 @@ from llama_stack.apis.agents import (
MemoryQueryGeneratorConfig,
)
from llama_stack.apis.inference import * # noqa: F403
+ from llama_stack.providers.utils.inference.prompt_adapter import (
+ interleaved_content_as_str,
+ )
async def generate_rag_query(
@ -42,7 +45,7 @@ async def default_rag_query_generator(
messages: List[Message],
**kwargs,
):
- return config.sep.join(interleaved_text_media_as_str(m.content) for m in messages)
+ return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
async def llm_rag_query_generator(
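
For readers unfamiliar with the helper being swapped in here: interleaved_content_as_str flattens a message's interleaved content (plain strings plus structured text/image items) into one string, which the default query generator then joins across messages with the configured separator. A minimal sketch of the idea, using made-up stand-in types rather than the real llama_stack content classes:

from dataclasses import dataclass
from typing import List, Union

@dataclass
class FakeTextItem:  # stand-in for a structured text item
    text: str

@dataclass
class FakeImageItem:  # stand-in for an image item
    url: str

Content = Union[str, "FakeTextItem", "FakeImageItem", List["Content"]]  # recursive alias

def as_str(content, sep: str = " ") -> str:
    # Flatten interleaved content into a single string; media becomes a placeholder tag.
    if isinstance(content, str):
        return content
    if isinstance(content, FakeTextItem):
        return content.text
    if isinstance(content, FakeImageItem):
        return "<image>"
    return sep.join(as_str(c, sep) for c in content)

def default_rag_query(messages, sep: str = "\n") -> str:
    # Mirrors the join in the diff: one flattened string per message, joined by the separator.
    return sep.join(as_str(m) for m in messages)

print(default_rag_query(["where are memory banks documented?",
                         [FakeTextItem(text="also see"), FakeImageItem(url="diagram.png")]]))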

View file

@ -114,7 +114,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
async def completion(
self,
model_id: str,
- content: InterleavedTextMedia,
+ content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -218,8 +218,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
yield chunk
async def embeddings(
- self, model_id: str, contents: list[InterleavedTextMedia]
+ self, model_id: str, contents: List[InterleavedContent]
) -> EmbeddingsResponse:
log.info("vLLM embeddings")
# TODO
raise NotImplementedError()

View file

@ -4,12 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+ from typing import Dict
+ from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import ChromaInlineImplConfig
- async def get_provider_impl(config: ChromaInlineImplConfig, _deps):
+ async def get_provider_impl(
+ config: ChromaInlineImplConfig, deps: Dict[Api, ProviderSpec]
+ ):
from llama_stack.providers.remote.memory.chroma.chroma import ChromaMemoryAdapter
- impl = ChromaMemoryAdapter(config)
+ impl = ChromaMemoryAdapter(config, deps[Api.inference])
await impl.initialize()
return impl
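
The signature change above is the standard provider dependency-injection pattern: instead of an unused _deps placeholder, the factory receives a mapping from Api to an already-constructed provider implementation and hands the inference API to the memory adapter, which can then use it to compute embeddings. A rough sketch of that wiring, with hypothetical class names in place of the real llama_stack ones:

import asyncio
from enum import Enum
from typing import Dict

class Api(Enum):  # stand-in for llama_stack's Api enum
    inference = "inference"
    memory = "memory"

class DummyInference:
    async def embeddings(self, model_id: str, contents: list) -> list:
        return [[0.0, 0.0, 0.0] for _ in contents]  # fake embedding vectors

class DemoMemoryAdapter:
    def __init__(self, config: dict, inference_api) -> None:
        self.config = config
        self.inference_api = inference_api  # kept so document inserts can call embeddings()

    async def initialize(self) -> None:
        pass

async def get_provider_impl(config: dict, deps: Dict[Api, object]):
    impl = DemoMemoryAdapter(config, deps[Api.inference])  # pull the declared dependency
    await impl.initialize()
    return impl

impl = asyncio.run(get_provider_impl({}, {Api.inference: DummyInference()}))
print(type(impl).__name__)

The matching registry change further down (api_dependencies=[Api.inference]) is what tells the stack to construct and pass that dependency in the first place.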

View file

@ -19,9 +19,10 @@ from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
+ from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
@ -208,7 +209,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
async def query_documents(
self,
bank_id: str,
- query: InterleavedTextMedia,
+ query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = self.cache.get(bank_id)

View file

@ -15,6 +15,9 @@ from llama_stack.apis.safety import * # noqa: F403
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
+ from llama_stack.providers.utils.inference.prompt_adapter import (
+ interleaved_content_as_str,
+ )
from .config import LlamaGuardConfig
@ -258,18 +261,18 @@ class LlamaGuardShield:
most_recent_img = None
for m in messages[::-1]:
- if isinstance(m.content, str):
+ if isinstance(m.content, str) or isinstance(m.content, TextContentItem):
conversation.append(m)
- elif isinstance(m.content, ImageMedia):
+ elif isinstance(m.content, ImageContentItem):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = m.content
conversation.append(m)
elif isinstance(m.content, list):
content = []
for c in m.content:
- if isinstance(c, str):
+ if isinstance(c, str) or isinstance(c, TextContentItem):
content.append(c)
- elif isinstance(c, ImageMedia):
+ elif isinstance(c, ImageContentItem):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = c
content.append(c)
@ -292,7 +295,7 @@ class LlamaGuardShield:
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
[
f"{m.role.capitalize()}: {interleaved_text_media_as_str(m.content)}"
f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}"
for m in messages
]
)
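
As a rough illustration of what the loop above is doing (simplified, with made-up types instead of the real content items): text-like content is kept verbatim, the most recent image sent by the user is remembered separately, and the surviving messages are rendered as "Role: flattened text" lines for the shield prompt.

from dataclasses import dataclass
from typing import List, Optional, Union

@dataclass
class DemoTextItem:
    text: str

@dataclass
class DemoImageItem:
    url: str

@dataclass
class DemoMessage:
    role: str  # "user" or "assistant"
    content: Union[str, DemoTextItem, DemoImageItem, List[Union[str, DemoTextItem, DemoImageItem]]]

def flatten(content) -> str:
    if isinstance(content, str):
        return content
    if isinstance(content, DemoTextItem):
        return content.text
    if isinstance(content, DemoImageItem):
        return "<image>"
    return " ".join(flatten(c) for c in content)

def build_prompt_parts(messages: List[DemoMessage]):
    most_recent_img: Optional[DemoImageItem] = None
    for m in reversed(messages):  # same newest-to-oldest walk as messages[::-1]
        items = m.content if isinstance(m.content, list) else [m.content]
        for c in items:
            if isinstance(c, DemoImageItem) and most_recent_img is None and m.role == "user":
                most_recent_img = c
    conversation = "\n\n".join(
        f"{m.role.capitalize()}: {flatten(m.content)}" for m in messages
    )
    return conversation, most_recent_img

conversation, img = build_prompt_parts(
    [DemoMessage("user", [DemoTextItem(text="is this image safe?"), DemoImageItem(url="cat.png")])]
)
print(conversation)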

View file

@ -17,6 +17,9 @@ from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
+ from llama_stack.providers.utils.inference.prompt_adapter import (
+ interleaved_content_as_str,
+ )
from .config import PromptGuardConfig, PromptGuardType
@ -83,7 +86,7 @@ class PromptGuardShield:
async def run(self, messages: List[Message]) -> RunShieldResponse:
message = messages[-1]
- text = interleaved_text_media_as_str(message.content)
+ text = interleaved_content_as_str(message.content)
# run model on messages and return response
inputs = self.tokenizer(text, return_tensors="pt")

View file

@ -65,6 +65,7 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=EMBEDDING_DEPS + ["chromadb"],
module="llama_stack.providers.inline.memory.chroma",
config_class="llama_stack.providers.inline.memory.chroma.ChromaInlineImplConfig",
+ api_dependencies=[Api.inference],
),
remote_provider_spec(
Api.memory,

View file

@ -9,21 +9,24 @@ import json
import uuid
from botocore.client import BaseClient
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
+ from llama_stack.providers.utils.inference.prompt_adapter import (
+ content_has_media,
+ interleaved_content_as_str,
+ )
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
- from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
MODEL_ALIASES = [
@ -64,7 +67,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def completion(
self,
model_id: str,
- content: InterleavedTextMedia,
+ content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -449,7 +452,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model_id: str,
- contents: List[InterleavedTextMedia],
+ contents: List[InterleavedContent],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embeddings = []
@ -457,7 +460,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
assert not content_has_media(
content
), "Bedrock does not support media for embeddings"
- input_text = interleaved_text_media_as_str(content)
+ input_text = interleaved_content_as_str(content)
input_body = {"inputText": input_text}
body = json.dumps(input_body)
response = self.client.invoke_model(

View file

@ -69,7 +69,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
async def completion(
self,
model_id: str,
- content: InterleavedTextMedia,
+ content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -166,11 +166,11 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
raise ValueError("`top_k` not supported by Cerebras")
prompt = ""
- if type(request) == ChatCompletionRequest:
+ if isinstance(request, ChatCompletionRequest):
prompt = chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
)
- elif type(request) == CompletionRequest:
+ elif isinstance(request, CompletionRequest):
prompt = completion_request_to_prompt(request, self.formatter)
else:
raise ValueError(f"Unknown request type {type(request)}")
@ -185,6 +185,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model_id: str,
- contents: List[InterleavedTextMedia],
+ contents: List[InterleavedContent],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -62,7 +62,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
async def completion(
self,
model: str,
- content: InterleavedTextMedia,
+ content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -135,6 +135,6 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model: str,
- contents: List[InterleavedTextMedia],
+ contents: List[InterleavedContent],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -8,14 +8,7 @@ import warnings
from typing import AsyncIterator, List, Optional, Union
from llama_models.datatypes import SamplingParams
- from llama_models.llama3.api.datatypes import (
- ImageMedia,
- InterleavedTextMedia,
- Message,
- ToolChoice,
- ToolDefinition,
- ToolPromptFormat,
- )
+ from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
from llama_models.sku_list import CoreModelId
from openai import APIConnectionError, AsyncOpenAI
@ -28,13 +21,17 @@ from llama_stack.apis.inference import (
CompletionResponseStreamChunk,
EmbeddingsResponse,
Inference,
+ InterleavedContent,
LogProbConfig,
+ Message,
ResponseFormat,
+ ToolChoice,
)
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
+ from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
from . import NVIDIAConfig
from .openai_utils import (
@ -123,17 +120,14 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
async def completion(
self,
model_id: str,
- content: InterleavedTextMedia,
+ content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
- if isinstance(content, ImageMedia) or (
- isinstance(content, list)
- and any(isinstance(c, ImageMedia) for c in content)
- ):
- raise NotImplementedError("ImageMedia is not supported")
+ if content_has_media(content):
+ raise NotImplementedError("Media is not supported")
await check_health(self._config) # this raises errors
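
Replacing the hand-rolled ImageMedia checks with a single content_has_media(content) call centralizes the "does this request contain non-text items?" question. The real helper lives in llama_stack.providers.utils.inference.prompt_adapter; the sketch below only illustrates the shape of such a check, with stand-in types rather than the actual content model:

from dataclasses import dataclass
from typing import List, Union

@dataclass
class StubTextItem:
    text: str

@dataclass
class StubImageItem:
    url: str

StubContent = Union[str, StubTextItem, StubImageItem, List[Union[str, StubTextItem, StubImageItem]]]

def has_media(content: StubContent) -> bool:
    # True if any item in the (possibly single-element) content is a media item.
    items = content if isinstance(content, list) else [content]
    return any(isinstance(item, StubImageItem) for item in items)

def completion(content: StubContent) -> str:
    if has_media(content):  # mirrors the guard added in the diff
        raise NotImplementedError("Media is not supported")
    return "ok"

print(completion(["describe", StubTextItem(text="this phrase")]))  # ok
# completion([StubImageItem(url="x.png")]) would raise NotImplementedError
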
@ -165,7 +159,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
async def embeddings(
self,
model_id: str,
- contents: List[InterleavedTextMedia],
+ contents: List[InterleavedContent],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -6,13 +6,14 @@
import asyncio
import json
import logging
- from typing import List
+ from typing import List, Optional, Union
from urllib.parse import urlparse
import chromadb
from numpy.typing import NDArray
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.memory_banks import MemoryBankType
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
from llama_stack.providers.utils.memory.vector_store import (
@ -151,7 +152,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def query_documents(
self,
bank_id: str,
- query: InterleavedTextMedia,
+ query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)

View file

@ -15,7 +15,7 @@ from psycopg2.extras import execute_values, Json
from pydantic import BaseModel, parse_obj_as
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
@ -188,7 +188,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def query_documents(
self,
bank_id: str,
- query: InterleavedTextMedia,
+ query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)

View file

@ -13,8 +13,7 @@ from qdrant_client import AsyncQdrantClient, models
from qdrant_client.models import PointStruct
from llama_stack.apis.memory_banks import * # noqa: F403
- from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+ from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig
@ -160,7 +159,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def query_documents(
self,
bank_id: str,
- query: InterleavedTextMedia,
+ query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)

View file

@ -15,6 +15,7 @@ from weaviate.classes.init import Auth
from weaviate.classes.query import Filter
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.memory_banks import MemoryBankType
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
@ -186,7 +187,7 @@ class WeaviateMemoryAdapter(
async def query_documents(
self,
bank_id: str,
- query: InterleavedTextMedia,
+ query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)

View file

@ -81,13 +81,13 @@ def pytest_addoption(parser):
parser.addoption(
"--inference-model",
action="store",
default="meta-llama/Llama-3.1-8B-Instruct",
default="meta-llama/Llama-3.2-3B-Instruct",
help="Specify the inference model to use for testing",
)
parser.addoption(
"--safety-shield",
action="store",
default="meta-llama/Llama-Guard-3-8B",
default="meta-llama/Llama-Guard-3-1B",
help="Specify the safety shield to use for testing",
)

View file

@ -192,6 +192,19 @@ def inference_tgi() -> ProviderFixture:
)
+ @pytest.fixture(scope="session")
+ def inference_sentence_transformers() -> ProviderFixture:
+ return ProviderFixture(
+ providers=[
+ Provider(
+ provider_id="sentence_transformers",
+ provider_type="inline::sentence-transformers",
+ config={},
+ )
+ ]
+ )
def get_model_short_name(model_name: str) -> str:
"""Convert model name to a short test identifier.

View file

@ -15,23 +15,23 @@ from .fixtures import MEMORY_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"inference": "meta_reference",
"inference": "sentence_transformers",
"memory": "faiss",
},
id="meta_reference",
marks=pytest.mark.meta_reference,
id="sentence_transformers",
marks=pytest.mark.sentence_transformers,
),
pytest.param(
{
"inference": "ollama",
"memory": "pgvector",
"memory": "faiss",
},
id="ollama",
marks=pytest.mark.ollama,
),
pytest.param(
{
"inference": "together",
"inference": "sentence_transformers",
"memory": "chroma",
},
id="chroma",
@ -58,10 +58,10 @@ DEFAULT_PROVIDER_COMBINATIONS = [
def pytest_addoption(parser):
parser.addoption(
"--inference-model",
"--embedding-model",
action="store",
default=None,
help="Specify the inference model to use for testing",
help="Specify the embedding model to use for testing",
)
@ -74,15 +74,15 @@ def pytest_configure(config):
def pytest_generate_tests(metafunc):
if "inference_model" in metafunc.fixturenames:
model = metafunc.config.getoption("--inference-model")
if not model:
raise ValueError(
"No inference model specified. Please provide a valid inference model."
)
if "embedding_model" in metafunc.fixturenames:
model = metafunc.config.getoption("--embedding-model")
if model:
params = [pytest.param(model, id="")]
else:
params = [pytest.param("all-MiniLM-L6-v2", id="")]
metafunc.parametrize("embedding_model", params, indirect=True)
metafunc.parametrize("inference_model", params, indirect=True)
if "memory_stack" in metafunc.fixturenames:
available_fixtures = {
"inference": INFERENCE_FIXTURES,

View file

@ -24,6 +24,13 @@ from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
+ @pytest.fixture(scope="session")
+ def embedding_model(request):
+ if hasattr(request, "param"):
+ return request.param
+ return request.config.getoption("--embedding-model", None)
@pytest.fixture(scope="session")
def memory_remote() -> ProviderFixture:
return remote_stack_fixture()
@ -107,7 +114,7 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
@pytest_asyncio.fixture(scope="session")
- async def memory_stack(inference_model, request):
+ async def memory_stack(embedding_model, request):
fixture_dict = request.param
providers = {}
@ -124,7 +131,7 @@ async def memory_stack(inference_model, request):
provider_data,
models=[
ModelInput(
- model_id=inference_model,
+ model_id=embedding_model,
model_type=ModelType.embedding,
metadata={
"embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),

View file

@ -46,13 +46,13 @@ def sample_documents():
async def register_memory_bank(
- banks_impl: MemoryBanks, inference_model: str
+ banks_impl: MemoryBanks, embedding_model: str
) -> MemoryBank:
bank_id = f"test_bank_{uuid.uuid4().hex}"
return await banks_impl.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
- embedding_model=inference_model,
+ embedding_model=embedding_model,
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
@ -61,11 +61,11 @@ async def register_memory_bank(
class TestMemory:
@pytest.mark.asyncio
- async def test_banks_list(self, memory_stack, inference_model):
+ async def test_banks_list(self, memory_stack, embedding_model):
_, banks_impl = memory_stack
# Register a test bank
- registered_bank = await register_memory_bank(banks_impl, inference_model)
+ registered_bank = await register_memory_bank(banks_impl, embedding_model)
try:
# Verify our bank shows up in list
@ -86,7 +86,7 @@ class TestMemory:
)
@pytest.mark.asyncio
- async def test_banks_register(self, memory_stack, inference_model):
+ async def test_banks_register(self, memory_stack, embedding_model):
_, banks_impl = memory_stack
bank_id = f"test_bank_{uuid.uuid4().hex}"
@ -96,7 +96,7 @@ class TestMemory:
await banks_impl.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
- embedding_model=inference_model,
+ embedding_model=embedding_model,
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
@ -111,7 +111,7 @@ class TestMemory:
await banks_impl.register_memory_bank(
memory_bank_id=bank_id,
params=VectorMemoryBankParams(
- embedding_model=inference_model,
+ embedding_model=embedding_model,
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
@ -129,14 +129,14 @@ class TestMemory:
@pytest.mark.asyncio
async def test_query_documents(
- self, memory_stack, inference_model, sample_documents
+ self, memory_stack, embedding_model, sample_documents
):
memory_impl, banks_impl = memory_stack
with pytest.raises(ValueError):
await memory_impl.insert_documents("test_bank", sample_documents)
- registered_bank = await register_memory_bank(banks_impl, inference_model)
+ registered_bank = await register_memory_bank(banks_impl, embedding_model)
await memory_impl.insert_documents(
registered_bank.memory_bank_id, sample_documents
)

View file

@ -74,7 +74,9 @@ def pytest_addoption(parser):
SAFETY_SHIELD_PARAMS = [
pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
pytest.param(
"meta-llama/Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"
),
]
@ -86,6 +88,7 @@ def pytest_generate_tests(metafunc):
if "safety_shield" in metafunc.fixturenames:
shield_id = metafunc.config.getoption("--safety-shield")
if shield_id:
+ assert shield_id.startswith("meta-llama/")
params = [pytest.param(shield_id, id="")]
else:
params = SAFETY_SHIELD_PARAMS

View file

@ -10,6 +10,7 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
+ from llama_stack.apis.inference import UserMessage
# How to run this test:
#

View file

@ -22,7 +22,11 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.memory_banks import VectorMemoryBank
from llama_stack.providers.datatypes import Api
+ from llama_stack.providers.utils.inference.prompt_adapter import (
+ interleaved_content_as_str,
+ )
log = logging.getLogger(__name__)
@ -108,7 +112,7 @@ async def content_from_doc(doc: MemoryBankDocument) -> str:
else:
return r.text
- return interleaved_text_media_as_str(doc.content)
+ return interleaved_content_as_str(doc.content)
def make_overlapped_chunks(
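
make_overlapped_chunks, referenced just above, is the piece the memory tests exercise when they register banks with chunk_size_in_tokens=512 and overlap_size_in_tokens=64: documents are cut into fixed-size token windows that overlap so context is not lost at chunk boundaries. A simplified sketch of that idea over a plain token list (illustrative only, not the actual implementation, which works on tokenizer output):

from typing import List

def overlapped_chunks(tokens: List[str], size: int, overlap: int) -> List[List[str]]:
    # Slide a window of `size` tokens, advancing by `size - overlap` each step.
    assert 0 <= overlap < size, "overlap must be smaller than the chunk size"
    step = size - overlap
    return [tokens[i : i + size] for i in range(0, len(tokens), step)]

words = "memory banks split documents into overlapping chunks for retrieval".split()
for chunk in overlapped_chunks(words, size=4, overlap=1):
    print(chunk)
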
@ -174,7 +178,7 @@ class BankWithIndex:
async def query_documents(
self,
- query: InterleavedTextMedia,
+ query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
if params is None: