Memory tests pass now

Ashwin Bharambe 2024-12-15 20:55:06 -08:00
parent e51154964f
commit 59ce047aea
23 changed files with 122 additions and 81 deletions

View file

@@ -17,6 +17,9 @@ from llama_stack.apis.agents import (
     MemoryQueryGeneratorConfig,
 )
 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)
 async def generate_rag_query(
@@ -42,7 +45,7 @@ async def default_rag_query_generator(
     messages: List[Message],
     **kwargs,
 ):
-    return config.sep.join(interleaved_text_media_as_str(m.content) for m in messages)
+    return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
 async def llm_rag_query_generator(
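For readers following the rename, here is a minimal sketch (not part of the commit) of how the interleaved_content_as_str helper is used on this path; the plain sep argument below stands in for config.sep and is an assumption.

    from llama_stack.providers.utils.inference.prompt_adapter import (
        interleaved_content_as_str,
    )

    def build_default_rag_query(messages, sep: str = " ") -> str:
        # Each message's content may be a plain string or a list of content items;
        # interleaved_content_as_str flattens either form into a single string.
        return sep.join(interleaved_content_as_str(m.content) for m in messages)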

View file

@@ -114,7 +114,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -218,8 +218,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             yield chunk
     async def embeddings(
-        self, model_id: str, contents: list[InterleavedTextMedia]
+        self, model_id: str, contents: List[InterleavedContent]
     ) -> EmbeddingsResponse:
-        log.info("vLLM embeddings")
-        # TODO
         raise NotImplementedError()

View file

@@ -4,12 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from typing import Dict
+from llama_stack.providers.datatypes import Api, ProviderSpec
 from .config import ChromaInlineImplConfig
-async def get_provider_impl(config: ChromaInlineImplConfig, _deps):
+async def get_provider_impl(
+    config: ChromaInlineImplConfig, deps: Dict[Api, ProviderSpec]
+):
     from llama_stack.providers.remote.memory.chroma.chroma import ChromaMemoryAdapter
-    impl = ChromaMemoryAdapter(config)
+    impl = ChromaMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
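A hedged sketch of how a caller could satisfy the new signature, assuming the package re-exports get_provider_impl as the hunk suggests; the way deps is assembled here is illustrative (the real resolver builds this mapping from the run config), and inference_impl is a placeholder.

    from llama_stack.providers.datatypes import Api
    from llama_stack.providers.inline.memory.chroma import (
        ChromaInlineImplConfig,
        get_provider_impl,
    )

    async def build_inline_chroma(config: ChromaInlineImplConfig, inference_impl):
        # The adapter now receives the inference provider so it can compute embeddings itself.
        deps = {Api.inference: inference_impl}
        return await get_provider_impl(config, deps)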

View file

@@ -19,9 +19,10 @@ from numpy.typing import NDArray
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.inference import InterleavedContent
+from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank
 from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
@@ -208,7 +209,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
     async def query_documents(
         self,
         bank_id: str,
-        query: InterleavedTextMedia,
+        query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
         index = self.cache.get(bank_id)

View file

@@ -15,6 +15,9 @@ from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)
 from .config import LlamaGuardConfig
@@ -258,18 +261,18 @@ class LlamaGuardShield:
         most_recent_img = None
         for m in messages[::-1]:
-            if isinstance(m.content, str):
+            if isinstance(m.content, str) or isinstance(m.content, TextContentItem):
                 conversation.append(m)
-            elif isinstance(m.content, ImageMedia):
+            elif isinstance(m.content, ImageContentItem):
                 if most_recent_img is None and m.role == Role.user.value:
                     most_recent_img = m.content
                 conversation.append(m)
             elif isinstance(m.content, list):
                 content = []
                 for c in m.content:
-                    if isinstance(c, str):
+                    if isinstance(c, str) or isinstance(c, TextContentItem):
                         content.append(c)
-                    elif isinstance(c, ImageMedia):
+                    elif isinstance(c, ImageContentItem):
                         if most_recent_img is None and m.role == Role.user.value:
                             most_recent_img = c
                         content.append(c)
@@ -292,7 +295,7 @@ class LlamaGuardShield:
         categories_str = "\n".join(categories)
         conversations_str = "\n\n".join(
             [
-                f"{m.role.capitalize()}: {interleaved_text_media_as_str(m.content)}"
+                f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}"
                 for m in messages
             ]
         )
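To make the new branching concrete, a self-contained sketch of the text-vs-image dispatch the shield now performs. Only the TextContentItem / ImageContentItem class names are confirmed by the diff; the stand-in dataclasses and their fields below are assumptions so the snippet runs on its own.

    from dataclasses import dataclass

    @dataclass
    class TextContentItem:  # stand-in for the stack's text content item
        text: str

    @dataclass
    class ImageContentItem:  # stand-in; the real field names are an assumption
        url: str

    def split_content(items):
        # Mirrors the loop above: strings and text items are kept as text,
        # image items are tracked separately.
        texts, images = [], []
        for c in items:
            if isinstance(c, (str, TextContentItem)):
                texts.append(c)
            elif isinstance(c, ImageContentItem):
                images.append(c)
        return texts, images

    texts, images = split_content(
        [TextContentItem(text="Is this image ok?"), ImageContentItem(url="https://example.com/cat.png")]
    )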

View file

@@ -17,6 +17,9 @@ from llama_stack.apis.safety import *  # noqa: F403
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)
 from .config import PromptGuardConfig, PromptGuardType
@@ -83,7 +86,7 @@ class PromptGuardShield:
     async def run(self, messages: List[Message]) -> RunShieldResponse:
         message = messages[-1]
-        text = interleaved_text_media_as_str(message.content)
+        text = interleaved_content_as_str(message.content)
         # run model on messages and return response
         inputs = self.tokenizer(text, return_tensors="pt")

View file

@@ -65,6 +65,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=EMBEDDING_DEPS + ["chromadb"],
             module="llama_stack.providers.inline.memory.chroma",
             config_class="llama_stack.providers.inline.memory.chroma.ChromaInlineImplConfig",
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             Api.memory,
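For orientation, a sketch of what the full registry entry roughly looks like with the new dependency declared. The InlineProviderSpec wrapper and the provider_type value are assumptions; only the fields visible in the hunk are confirmed.

    from llama_stack.providers.datatypes import Api, InlineProviderSpec

    chroma_inline_spec = InlineProviderSpec(
        api=Api.memory,
        provider_type="inline::chromadb",  # assumed identifier
        pip_packages=["chromadb"],  # EMBEDDING_DEPS omitted here for brevity
        module="llama_stack.providers.inline.memory.chroma",
        config_class="llama_stack.providers.inline.memory.chroma.ChromaInlineImplConfig",
        api_dependencies=[Api.inference],  # the memory provider now needs inference for embeddings
    )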

View file

@@ -9,21 +9,24 @@ import json
 import uuid
 from botocore.client import BaseClient
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import ToolParamDefinition
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    content_has_media,
+    interleaved_content_as_str,
+)
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
-from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
 MODEL_ALIASES = [
@@ -64,7 +67,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -449,7 +452,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embeddings = []
@@ -457,7 +460,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
            assert not content_has_media(
                content
            ), "Bedrock does not support media for embeddings"
-           input_text = interleaved_text_media_as_str(content)
+           input_text = interleaved_content_as_str(content)
            input_body = {"inputText": input_text}
            body = json.dumps(input_body)
            response = self.client.invoke_model(

View file

@@ -69,7 +69,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -166,11 +166,11 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
             raise ValueError("`top_k` not supported by Cerebras")
         prompt = ""
-        if type(request) == ChatCompletionRequest:
+        if isinstance(request, ChatCompletionRequest):
             prompt = chat_completion_request_to_prompt(
                 request, self.get_llama_model(request.model), self.formatter
             )
-        elif type(request) == CompletionRequest:
+        elif isinstance(request, CompletionRequest):
             prompt = completion_request_to_prompt(request, self.formatter)
         else:
             raise ValueError(f"Unknown request type {type(request)}")
@@ -185,6 +185,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
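A quick, self-contained illustration of why the type()-equality checks above were swapped for isinstance: the equality form rejects subclasses, which silently breaks dispatch if a request type is ever subclassed. The stub classes below are for illustration only.

    class Request:
        pass

    class ChatCompletionRequestStub(Request):  # stand-in subclass
        pass

    r = ChatCompletionRequestStub()
    print(type(r) == Request)      # False: exact-type comparison misses subclasses
    print(isinstance(r, Request))  # True: matches the class and any subclass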

View file

@@ -62,7 +62,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     async def completion(
         self,
         model: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -135,6 +135,6 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

View file

@@ -8,14 +8,7 @@ import warnings
 from typing import AsyncIterator, List, Optional, Union
 from llama_models.datatypes import SamplingParams
-from llama_models.llama3.api.datatypes import (
-    ImageMedia,
-    InterleavedTextMedia,
-    Message,
-    ToolChoice,
-    ToolDefinition,
-    ToolPromptFormat,
-)
+from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
 from llama_models.sku_list import CoreModelId
 from openai import APIConnectionError, AsyncOpenAI
@@ -28,13 +21,17 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
     Inference,
+    InterleavedContent,
     LogProbConfig,
+    Message,
     ResponseFormat,
+    ToolChoice,
 )
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
 from . import NVIDIAConfig
 from .openai_utils import (
@@ -123,17 +120,14 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
-        if isinstance(content, ImageMedia) or (
-            isinstance(content, list)
-            and any(isinstance(c, ImageMedia) for c in content)
-        ):
-            raise NotImplementedError("ImageMedia is not supported")
+        if content_has_media(content):
+            raise NotImplementedError("Media is not supported")
         await check_health(self._config)  # this raises errors
@@ -165,7 +159,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
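A minimal sketch of the guard pattern the adapter now uses; content_has_media is the helper imported above (it replaces the explicit ImageMedia checks), and the sample string is purely illustrative.

    from llama_stack.providers.utils.inference.prompt_adapter import content_has_media

    def ensure_text_only(content):
        # Fail fast when the interleaved content carries images the backend cannot accept.
        if content_has_media(content):
            raise NotImplementedError("Media is not supported")
        return content

    ensure_text_only("Summarize the quarterly report in two sentences.")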

View file

@@ -6,13 +6,14 @@
 import asyncio
 import json
 import logging
-from typing import List
+from typing import List, Optional, Union
 from urllib.parse import urlparse
 import chromadb
 from numpy.typing import NDArray
 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.memory_banks import MemoryBankType
 from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
 from llama_stack.providers.utils.memory.vector_store import (
@@ -151,7 +152,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     async def query_documents(
         self,
         bank_id: str,
-        query: InterleavedTextMedia,
+        query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
         index = await self._get_and_cache_bank_index(bank_id)

View file

@@ -15,7 +15,7 @@ from psycopg2.extras import execute_values, Json
 from pydantic import BaseModel, parse_obj_as
 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank
 from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
@@ -188,7 +188,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     async def query_documents(
         self,
         bank_id: str,
-        query: InterleavedTextMedia,
+        query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
         index = await self._get_and_cache_bank_index(bank_id)

View file

@@ -13,8 +13,7 @@ from qdrant_client import AsyncQdrantClient, models
 from qdrant_client.models import PointStruct
 from llama_stack.apis.memory_banks import *  # noqa: F403
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig
@@ -160,7 +159,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     async def query_documents(
         self,
         bank_id: str,
-        query: InterleavedTextMedia,
+        query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
         index = await self._get_and_cache_bank_index(bank_id)

View file

@@ -15,6 +15,7 @@ from weaviate.classes.init import Auth
 from weaviate.classes.query import Filter
 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.memory_banks import MemoryBankType
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
@@ -186,7 +187,7 @@ class WeaviateMemoryAdapter(
     async def query_documents(
         self,
         bank_id: str,
-        query: InterleavedTextMedia,
+        query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
         index = await self._get_and_cache_bank_index(bank_id)

View file

@@ -81,13 +81,13 @@ def pytest_addoption(parser):
     parser.addoption(
         "--inference-model",
         action="store",
-        default="meta-llama/Llama-3.1-8B-Instruct",
+        default="meta-llama/Llama-3.2-3B-Instruct",
         help="Specify the inference model to use for testing",
     )
     parser.addoption(
         "--safety-shield",
         action="store",
-        default="meta-llama/Llama-Guard-3-8B",
+        default="meta-llama/Llama-Guard-3-1B",
         help="Specify the safety shield to use for testing",
     )

View file

@@ -192,6 +192,19 @@ def inference_tgi() -> ProviderFixture:
     )
+@pytest.fixture(scope="session")
+def inference_sentence_transformers() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="sentence_transformers",
+                provider_type="inline::sentence-transformers",
+                config={},
+            )
+        ]
+    )
 def get_model_short_name(model_name: str) -> str:
     """Convert model name to a short test identifier.

View file

@@ -15,23 +15,23 @@ from .fixtures import MEMORY_FIXTURES
 DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
-            "inference": "meta_reference",
+            "inference": "sentence_transformers",
             "memory": "faiss",
         },
-        id="meta_reference",
-        marks=pytest.mark.meta_reference,
+        id="sentence_transformers",
+        marks=pytest.mark.sentence_transformers,
     ),
     pytest.param(
         {
             "inference": "ollama",
-            "memory": "pgvector",
+            "memory": "faiss",
         },
         id="ollama",
         marks=pytest.mark.ollama,
     ),
     pytest.param(
         {
-            "inference": "together",
+            "inference": "sentence_transformers",
             "memory": "chroma",
         },
         id="chroma",
@@ -58,10 +58,10 @@ DEFAULT_PROVIDER_COMBINATIONS = [
 def pytest_addoption(parser):
     parser.addoption(
-        "--inference-model",
+        "--embedding-model",
         action="store",
         default=None,
-        help="Specify the inference model to use for testing",
+        help="Specify the embedding model to use for testing",
     )
@@ -74,15 +74,15 @@ def pytest_configure(config):
 def pytest_generate_tests(metafunc):
-    if "inference_model" in metafunc.fixturenames:
-        model = metafunc.config.getoption("--inference-model")
-        if not model:
-            raise ValueError(
-                "No inference model specified. Please provide a valid inference model."
-            )
-        params = [pytest.param(model, id="")]
-        metafunc.parametrize("inference_model", params, indirect=True)
+    if "embedding_model" in metafunc.fixturenames:
+        model = metafunc.config.getoption("--embedding-model")
+        if model:
+            params = [pytest.param(model, id="")]
+        else:
+            params = [pytest.param("all-MiniLM-L6-v2", id="")]
+        metafunc.parametrize("embedding_model", params, indirect=True)
     if "memory_stack" in metafunc.fixturenames:
         available_fixtures = {
             "inference": INFERENCE_FIXTURES,

View file

@@ -24,6 +24,13 @@ from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail
+@pytest.fixture(scope="session")
+def embedding_model(request):
+    if hasattr(request, "param"):
+        return request.param
+    return request.config.getoption("--embedding-model", None)
 @pytest.fixture(scope="session")
 def memory_remote() -> ProviderFixture:
     return remote_stack_fixture()
@@ -107,7 +114,7 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
 @pytest_asyncio.fixture(scope="session")
-async def memory_stack(inference_model, request):
+async def memory_stack(embedding_model, request):
     fixture_dict = request.param
     providers = {}
@@ -124,7 +131,7 @@ async def memory_stack(inference_model, request):
             provider_data,
             models=[
                 ModelInput(
-                    model_id=inference_model,
+                    model_id=embedding_model,
                     model_type=ModelType.embedding,
                     metadata={
                         "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),

View file

@@ -46,13 +46,13 @@ def sample_documents():
 async def register_memory_bank(
-    banks_impl: MemoryBanks, inference_model: str
+    banks_impl: MemoryBanks, embedding_model: str
 ) -> MemoryBank:
     bank_id = f"test_bank_{uuid.uuid4().hex}"
     return await banks_impl.register_memory_bank(
         memory_bank_id=bank_id,
         params=VectorMemoryBankParams(
-            embedding_model=inference_model,
+            embedding_model=embedding_model,
             chunk_size_in_tokens=512,
             overlap_size_in_tokens=64,
         ),
@@ -61,11 +61,11 @@ async def register_memory_bank(
 class TestMemory:
     @pytest.mark.asyncio
-    async def test_banks_list(self, memory_stack, inference_model):
+    async def test_banks_list(self, memory_stack, embedding_model):
         _, banks_impl = memory_stack
         # Register a test bank
-        registered_bank = await register_memory_bank(banks_impl, inference_model)
+        registered_bank = await register_memory_bank(banks_impl, embedding_model)
         try:
             # Verify our bank shows up in list
@@ -86,7 +86,7 @@ class TestMemory:
         )
     @pytest.mark.asyncio
-    async def test_banks_register(self, memory_stack, inference_model):
+    async def test_banks_register(self, memory_stack, embedding_model):
         _, banks_impl = memory_stack
         bank_id = f"test_bank_{uuid.uuid4().hex}"
@@ -96,7 +96,7 @@ class TestMemory:
         await banks_impl.register_memory_bank(
             memory_bank_id=bank_id,
             params=VectorMemoryBankParams(
-                embedding_model=inference_model,
+                embedding_model=embedding_model,
                 chunk_size_in_tokens=512,
                 overlap_size_in_tokens=64,
             ),
@@ -111,7 +111,7 @@ class TestMemory:
         await banks_impl.register_memory_bank(
             memory_bank_id=bank_id,
             params=VectorMemoryBankParams(
-                embedding_model=inference_model,
+                embedding_model=embedding_model,
                 chunk_size_in_tokens=512,
                 overlap_size_in_tokens=64,
             ),
@@ -129,14 +129,14 @@ class TestMemory:
     @pytest.mark.asyncio
     async def test_query_documents(
-        self, memory_stack, inference_model, sample_documents
+        self, memory_stack, embedding_model, sample_documents
     ):
         memory_impl, banks_impl = memory_stack
         with pytest.raises(ValueError):
             await memory_impl.insert_documents("test_bank", sample_documents)
-        registered_bank = await register_memory_bank(banks_impl, inference_model)
+        registered_bank = await register_memory_bank(banks_impl, embedding_model)
         await memory_impl.insert_documents(
             registered_bank.memory_bank_id, sample_documents
         )

View file

@@ -74,7 +74,9 @@ def pytest_addoption(parser):
 SAFETY_SHIELD_PARAMS = [
-    pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
+    pytest.param(
+        "meta-llama/Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"
+    ),
 ]
@@ -86,6 +88,7 @@ def pytest_generate_tests(metafunc):
     if "safety_shield" in metafunc.fixturenames:
         shield_id = metafunc.config.getoption("--safety-shield")
         if shield_id:
+            assert shield_id.startswith("meta-llama/")
             params = [pytest.param(shield_id, id="")]
         else:
             params = SAFETY_SHIELD_PARAMS

View file

@@ -10,6 +10,7 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.distribution.datatypes import *  # noqa: F403
+from llama_stack.apis.inference import UserMessage
 # How to run this test:
 #

View file

@@ -22,7 +22,11 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.memory_banks import VectorMemoryBank
 from llama_stack.providers.datatypes import Api
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)
 log = logging.getLogger(__name__)
@@ -108,7 +112,7 @@ async def content_from_doc(doc: MemoryBankDocument) -> str:
     else:
         return r.text
-    return interleaved_text_media_as_str(doc.content)
+    return interleaved_content_as_str(doc.content)
 def make_overlapped_chunks(
@@ -174,7 +178,7 @@ class BankWithIndex:
     async def query_documents(
         self,
-        query: InterleavedTextMedia,
+        query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
         if params is None: