Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-03 09:21:45 +00:00

Commit 59ce047aea (parent e51154964f): "Memory tests pass now"

23 changed files with 122 additions and 81 deletions
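
Most of the hunks below track a single API rename: the InterleavedTextMedia type becomes InterleavedContent, and the helper interleaved_text_media_as_str becomes interleaved_content_as_str (imported from llama_stack.providers.utils.inference.prompt_adapter). The remaining hunks move the memory test suite from the inference_model fixture to a dedicated embedding_model fixture. A minimal migration sketch, assuming only the renamed symbols shown in the hunks; the render_messages helper itself is illustrative, not part of this commit:

    # Illustrative sketch; only the renamed import is taken from this diff.
    from llama_stack.providers.utils.inference.prompt_adapter import (
        interleaved_content_as_str,  # was: interleaved_text_media_as_str
    )


    def render_messages(messages) -> str:
        # Flatten each message's (possibly mixed text/image) content to plain
        # text, mirroring the join pattern used in the RAG-query and
        # LlamaGuard hunks below.
        return "\n\n".join(interleaved_content_as_str(m.content) for m in messages)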
@@ -17,6 +17,9 @@ from llama_stack.apis.agents import (
     MemoryQueryGeneratorConfig,
 )
 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)


 async def generate_rag_query(
@@ -42,7 +45,7 @@ async def default_rag_query_generator(
     messages: List[Message],
     **kwargs,
 ):
-    return config.sep.join(interleaved_text_media_as_str(m.content) for m in messages)
+    return config.sep.join(interleaved_content_as_str(m.content) for m in messages)


 async def llm_rag_query_generator(

@@ -114,7 +114,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -218,8 +218,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             yield chunk

     async def embeddings(
-        self, model_id: str, contents: list[InterleavedTextMedia]
+        self, model_id: str, contents: List[InterleavedContent]
     ) -> EmbeddingsResponse:
-        log.info("vLLM embeddings")
-        # TODO
         raise NotImplementedError()

@@ -4,12 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import ChromaInlineImplConfig


-async def get_provider_impl(config: ChromaInlineImplConfig, _deps):
+async def get_provider_impl(
+    config: ChromaInlineImplConfig, deps: Dict[Api, ProviderSpec]
+):
     from llama_stack.providers.remote.memory.chroma.chroma import ChromaMemoryAdapter

-    impl = ChromaMemoryAdapter(config)
+    impl = ChromaMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl

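The inline Chroma provider entrypoint now receives resolved API dependencies, matching the api_dependencies=[Api.inference] registry entry later in this diff. A hedged sketch of how the new signature is invoked; the wiring function, variable names, and import path are assumptions for illustration, not part of this commit:

    # Illustrative wiring only; the deps dict construction is assumed.
    from llama_stack.providers.datatypes import Api
    from llama_stack.providers.inline.memory.chroma import get_provider_impl


    async def build_inline_chroma(config, inference_impl):
        # The provider pulls the inference implementation out of deps
        # (deps[Api.inference] in the hunk above) to compute embeddings.
        deps = {Api.inference: inference_impl}
        return await get_provider_impl(config, deps)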
@@ -19,9 +19,10 @@ from numpy.typing import NDArray
 from llama_models.llama3.api.datatypes import *  # noqa: F403

 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.inference import InterleavedContent
+from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank
 from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl

 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
@@ -208,7 +209,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
     async def query_documents(
         self,
         bank_id: str,
-        query: InterleavedTextMedia,
+        query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
         index = self.cache.get(bank_id)

@@ -15,6 +15,9 @@ from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.distribution.datatypes import Api

 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)

 from .config import LlamaGuardConfig

@@ -258,18 +261,18 @@ class LlamaGuardShield:
         most_recent_img = None

         for m in messages[::-1]:
-            if isinstance(m.content, str):
+            if isinstance(m.content, str) or isinstance(m.content, TextContentItem):
                 conversation.append(m)
-            elif isinstance(m.content, ImageMedia):
+            elif isinstance(m.content, ImageContentItem):
                 if most_recent_img is None and m.role == Role.user.value:
                     most_recent_img = m.content
                 conversation.append(m)
             elif isinstance(m.content, list):
                 content = []
                 for c in m.content:
-                    if isinstance(c, str):
+                    if isinstance(c, str) or isinstance(c, TextContentItem):
                         content.append(c)
-                    elif isinstance(c, ImageMedia):
+                    elif isinstance(c, ImageContentItem):
                         if most_recent_img is None and m.role == Role.user.value:
                             most_recent_img = c
                         content.append(c)
@@ -292,7 +295,7 @@ class LlamaGuardShield:
         categories_str = "\n".join(categories)
         conversations_str = "\n\n".join(
             [
-                f"{m.role.capitalize()}: {interleaved_text_media_as_str(m.content)}"
+                f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}"
                 for m in messages
             ]
         )

@@ -17,6 +17,9 @@ from llama_stack.apis.safety import *  # noqa: F403
 from llama_models.llama3.api.datatypes import *  # noqa: F403

 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)

 from .config import PromptGuardConfig, PromptGuardType

@@ -83,7 +86,7 @@ class PromptGuardShield:

     async def run(self, messages: List[Message]) -> RunShieldResponse:
         message = messages[-1]
-        text = interleaved_text_media_as_str(message.content)
+        text = interleaved_content_as_str(message.content)

         # run model on messages and return response
         inputs = self.tokenizer(text, return_tensors="pt")

@@ -65,6 +65,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=EMBEDDING_DEPS + ["chromadb"],
             module="llama_stack.providers.inline.memory.chroma",
             config_class="llama_stack.providers.inline.memory.chroma.ChromaInlineImplConfig",
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             Api.memory,

@@ -9,21 +9,24 @@ import json
 import uuid
 from botocore.client import BaseClient
 from llama_models.datatypes import CoreModelId

 from llama_models.llama3.api.chat_format import ChatFormat

+from llama_models.llama3.api.datatypes import ToolParamDefinition
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    content_has_media,
+    interleaved_content_as_str,
+)

 from llama_stack.apis.inference import *  # noqa: F403


 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
-from llama_stack.providers.utils.inference.prompt_adapter import content_has_media


 MODEL_ALIASES = [
@@ -64,7 +67,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -449,7 +452,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embeddings = []
@@ -457,7 +460,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
             assert not content_has_media(
                 content
             ), "Bedrock does not support media for embeddings"
-            input_text = interleaved_text_media_as_str(content)
+            input_text = interleaved_content_as_str(content)
             input_body = {"inputText": input_text}
             body = json.dumps(input_body)
             response = self.client.invoke_model(

@@ -69,7 +69,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -166,11 +166,11 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
             raise ValueError("`top_k` not supported by Cerebras")

         prompt = ""
-        if type(request) == ChatCompletionRequest:
+        if isinstance(request, ChatCompletionRequest):
             prompt = chat_completion_request_to_prompt(
                 request, self.get_llama_model(request.model), self.formatter
             )
-        elif type(request) == CompletionRequest:
+        elif isinstance(request, CompletionRequest):
             prompt = completion_request_to_prompt(request, self.formatter)
         else:
             raise ValueError(f"Unknown request type {type(request)}")
@@ -185,6 +185,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

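The Cerebras hunk also swaps exact-type comparisons for isinstance checks, which keep working if request subclasses are ever passed in. A quick generic illustration; the class names here are made up, not from the adapter:

    class Base: ...
    class Sub(Base): ...

    obj = Sub()
    assert type(obj) == Sub            # exact class match only
    assert not (type(obj) == Base)     # misses subclasses
    assert isinstance(obj, Base)       # accepts Sub because it subclasses Base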
@@ -62,7 +62,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     async def completion(
         self,
         model: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -135,6 +135,6 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

@@ -8,14 +8,7 @@ import warnings
 from typing import AsyncIterator, List, Optional, Union

 from llama_models.datatypes import SamplingParams
-from llama_models.llama3.api.datatypes import (
-    ImageMedia,
-    InterleavedTextMedia,
-    Message,
-    ToolChoice,
-    ToolDefinition,
-    ToolPromptFormat,
-)
+from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
 from llama_models.sku_list import CoreModelId
 from openai import APIConnectionError, AsyncOpenAI

@@ -28,13 +21,17 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
     Inference,
+    InterleavedContent,
     LogProbConfig,
+    Message,
     ResponseFormat,
+    ToolChoice,
 )
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import content_has_media

 from . import NVIDIAConfig
 from .openai_utils import (
@@ -123,17 +120,14 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
-        if isinstance(content, ImageMedia) or (
-            isinstance(content, list)
-            and any(isinstance(c, ImageMedia) for c in content)
-        ):
-            raise NotImplementedError("ImageMedia is not supported")
+        if content_has_media(content):
+            raise NotImplementedError("Media is not supported")

         await check_health(self._config)  # this raises errors

@@ -165,7 +159,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

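The NVIDIA adapter drops its hand-rolled ImageMedia detection in favour of content_has_media from prompt_adapter. Pulled out of the class for illustration, the new guard amounts to the following; the standalone function wrapper is hypothetical, while the two lines inside it come straight from the hunk above:

    from llama_stack.providers.utils.inference.prompt_adapter import content_has_media


    def reject_media(content) -> None:
        # Any image/media content is refused up front, before a request is built.
        if content_has_media(content):
            raise NotImplementedError("Media is not supported")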
@@ -6,13 +6,14 @@
 import asyncio
 import json
 import logging
-from typing import List
+from typing import List, Optional, Union
 from urllib.parse import urlparse

 import chromadb
 from numpy.typing import NDArray

 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.memory_banks import MemoryBankType
 from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
 from llama_stack.providers.utils.memory.vector_store import (
@@ -151,7 +152,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     async def query_documents(
         self,
         bank_id: str,
-        query: InterleavedTextMedia,
+        query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
         index = await self._get_and_cache_bank_index(bank_id)

@@ -15,7 +15,7 @@ from psycopg2.extras import execute_values, Json
 from pydantic import BaseModel, parse_obj_as

 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank
 from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate

 from llama_stack.providers.utils.memory.vector_store import (
@@ -188,7 +188,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     async def query_documents(
         self,
         bank_id: str,
-        query: InterleavedTextMedia,
+        query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
         index = await self._get_and_cache_bank_index(bank_id)

@@ -13,8 +13,7 @@ from qdrant_client import AsyncQdrantClient, models
 from qdrant_client.models import PointStruct

 from llama_stack.apis.memory_banks import *  # noqa: F403
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate

 from llama_stack.apis.memory import *  # noqa: F403

 from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig
@@ -160,7 +159,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     async def query_documents(
         self,
         bank_id: str,
-        query: InterleavedTextMedia,
+        query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
         index = await self._get_and_cache_bank_index(bank_id)

@@ -15,6 +15,7 @@ from weaviate.classes.init import Auth
 from weaviate.classes.query import Filter

 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.memory_banks import MemoryBankType
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
@@ -186,7 +187,7 @@ class WeaviateMemoryAdapter(
     async def query_documents(
         self,
         bank_id: str,
-        query: InterleavedTextMedia,
+        query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
         index = await self._get_and_cache_bank_index(bank_id)

@@ -81,13 +81,13 @@ def pytest_addoption(parser):
     parser.addoption(
         "--inference-model",
         action="store",
-        default="meta-llama/Llama-3.1-8B-Instruct",
+        default="meta-llama/Llama-3.2-3B-Instruct",
         help="Specify the inference model to use for testing",
     )
     parser.addoption(
         "--safety-shield",
         action="store",
-        default="meta-llama/Llama-Guard-3-8B",
+        default="meta-llama/Llama-Guard-3-1B",
         help="Specify the safety shield to use for testing",
     )

@@ -192,6 +192,19 @@ def inference_tgi() -> ProviderFixture:
     )


+@pytest.fixture(scope="session")
+def inference_sentence_transformers() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="sentence_transformers",
+                provider_type="inline::sentence-transformers",
+                config={},
+            )
+        ]
+    )
+
+
 def get_model_short_name(model_name: str) -> str:
     """Convert model name to a short test identifier.

@@ -15,23 +15,23 @@ from .fixtures import MEMORY_FIXTURES
 DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
-            "inference": "meta_reference",
+            "inference": "sentence_transformers",
             "memory": "faiss",
         },
-        id="meta_reference",
-        marks=pytest.mark.meta_reference,
+        id="sentence_transformers",
+        marks=pytest.mark.sentence_transformers,
     ),
     pytest.param(
         {
             "inference": "ollama",
-            "memory": "pgvector",
+            "memory": "faiss",
         },
         id="ollama",
         marks=pytest.mark.ollama,
     ),
     pytest.param(
         {
-            "inference": "together",
+            "inference": "sentence_transformers",
             "memory": "chroma",
         },
         id="chroma",
@@ -58,10 +58,10 @@ DEFAULT_PROVIDER_COMBINATIONS = [

 def pytest_addoption(parser):
     parser.addoption(
-        "--inference-model",
+        "--embedding-model",
         action="store",
         default=None,
-        help="Specify the inference model to use for testing",
+        help="Specify the embedding model to use for testing",
     )

@@ -74,15 +74,15 @@ def pytest_configure(config):


 def pytest_generate_tests(metafunc):
-    if "inference_model" in metafunc.fixturenames:
-        model = metafunc.config.getoption("--inference-model")
-        if not model:
-            raise ValueError(
-                "No inference model specified. Please provide a valid inference model."
-            )
-        params = [pytest.param(model, id="")]
+    if "embedding_model" in metafunc.fixturenames:
+        model = metafunc.config.getoption("--embedding-model")
+        if model:
+            params = [pytest.param(model, id="")]
+        else:
+            params = [pytest.param("all-MiniLM-L6-v2", id="")]
+        metafunc.parametrize("embedding_model", params, indirect=True)

-        metafunc.parametrize("inference_model", params, indirect=True)
     if "memory_stack" in metafunc.fixturenames:
         available_fixtures = {
             "inference": INFERENCE_FIXTURES,

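The memory suite now keys its parametrization off --embedding-model and falls back to all-MiniLM-L6-v2 when the flag is omitted. A minimal sketch of that selection logic, assuming only the fallback value shown above; the helper name and the second example value are made up:

    # Sketch of the fallback behaviour in pytest_generate_tests above.
    def pick_embedding_model(cli_value):
        # --embedding-model wins when given; otherwise use the default
        # hard-coded in the hunk above.
        return cli_value if cli_value else "all-MiniLM-L6-v2"

    assert pick_embedding_model(None) == "all-MiniLM-L6-v2"
    assert pick_embedding_model("my-embedding-model") == "my-embedding-model"  # hypothetical value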
@@ -24,6 +24,13 @@ from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail


+@pytest.fixture(scope="session")
+def embedding_model(request):
+    if hasattr(request, "param"):
+        return request.param
+    return request.config.getoption("--embedding-model", None)
+
+
 @pytest.fixture(scope="session")
 def memory_remote() -> ProviderFixture:
     return remote_stack_fixture()
@@ -107,7 +114,7 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]


 @pytest_asyncio.fixture(scope="session")
-async def memory_stack(inference_model, request):
+async def memory_stack(embedding_model, request):
     fixture_dict = request.param

     providers = {}
@@ -124,7 +131,7 @@ async def memory_stack(inference_model, request):
         provider_data,
         models=[
             ModelInput(
-                model_id=inference_model,
+                model_id=embedding_model,
                 model_type=ModelType.embedding,
                 metadata={
                     "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),

@@ -46,13 +46,13 @@ def sample_documents():


 async def register_memory_bank(
-    banks_impl: MemoryBanks, inference_model: str
+    banks_impl: MemoryBanks, embedding_model: str
 ) -> MemoryBank:
     bank_id = f"test_bank_{uuid.uuid4().hex}"
     return await banks_impl.register_memory_bank(
         memory_bank_id=bank_id,
         params=VectorMemoryBankParams(
-            embedding_model=inference_model,
+            embedding_model=embedding_model,
             chunk_size_in_tokens=512,
             overlap_size_in_tokens=64,
         ),
@@ -61,11 +61,11 @@ async def register_memory_bank(

 class TestMemory:
     @pytest.mark.asyncio
-    async def test_banks_list(self, memory_stack, inference_model):
+    async def test_banks_list(self, memory_stack, embedding_model):
         _, banks_impl = memory_stack

         # Register a test bank
-        registered_bank = await register_memory_bank(banks_impl, inference_model)
+        registered_bank = await register_memory_bank(banks_impl, embedding_model)

         try:
             # Verify our bank shows up in list
@@ -86,7 +86,7 @@ class TestMemory:
             )

     @pytest.mark.asyncio
-    async def test_banks_register(self, memory_stack, inference_model):
+    async def test_banks_register(self, memory_stack, embedding_model):
         _, banks_impl = memory_stack

         bank_id = f"test_bank_{uuid.uuid4().hex}"
@@ -96,7 +96,7 @@ class TestMemory:
         await banks_impl.register_memory_bank(
             memory_bank_id=bank_id,
             params=VectorMemoryBankParams(
-                embedding_model=inference_model,
+                embedding_model=embedding_model,
                 chunk_size_in_tokens=512,
                 overlap_size_in_tokens=64,
             ),
@@ -111,7 +111,7 @@ class TestMemory:
         await banks_impl.register_memory_bank(
             memory_bank_id=bank_id,
             params=VectorMemoryBankParams(
-                embedding_model=inference_model,
+                embedding_model=embedding_model,
                 chunk_size_in_tokens=512,
                 overlap_size_in_tokens=64,
             ),
@@ -129,14 +129,14 @@ class TestMemory:

     @pytest.mark.asyncio
     async def test_query_documents(
-        self, memory_stack, inference_model, sample_documents
+        self, memory_stack, embedding_model, sample_documents
     ):
         memory_impl, banks_impl = memory_stack

         with pytest.raises(ValueError):
             await memory_impl.insert_documents("test_bank", sample_documents)

-        registered_bank = await register_memory_bank(banks_impl, inference_model)
+        registered_bank = await register_memory_bank(banks_impl, embedding_model)
         await memory_impl.insert_documents(
             registered_bank.memory_bank_id, sample_documents
         )

@@ -74,7 +74,9 @@ def pytest_addoption(parser):


 SAFETY_SHIELD_PARAMS = [
-    pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
+    pytest.param(
+        "meta-llama/Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"
+    ),
 ]


@@ -86,6 +88,7 @@ def pytest_generate_tests(metafunc):
     if "safety_shield" in metafunc.fixturenames:
         shield_id = metafunc.config.getoption("--safety-shield")
         if shield_id:
+            assert shield_id.startswith("meta-llama/")
             params = [pytest.param(shield_id, id="")]
         else:
             params = SAFETY_SHIELD_PARAMS

@@ -10,6 +10,7 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403

 from llama_stack.distribution.datatypes import *  # noqa: F403
+from llama_stack.apis.inference import UserMessage

 # How to run this test:
 #

@@ -22,7 +22,11 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.memory_banks import VectorMemoryBank
 from llama_stack.providers.datatypes import Api
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)

 log = logging.getLogger(__name__)

@@ -108,7 +112,7 @@ async def content_from_doc(doc: MemoryBankDocument) -> str:
         else:
             return r.text

-    return interleaved_text_media_as_str(doc.content)
+    return interleaved_content_as_str(doc.content)


 def make_overlapped_chunks(
@@ -174,7 +178,7 @@ class BankWithIndex:

     async def query_documents(
         self,
-        query: InterleavedTextMedia,
+        query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
         if params is None: