Make embedding generation go through inference (#606)

This PR does the following:
1) adds the ability to generate embeddings in all supported inference
providers.
2) Moves all the memory providers to use the inference API and improved
the memory tests to setup the inference stack correctly and use the
embedding models

This is a merge from #589 and #598
This commit is contained in:
Dinesh Yeduguru 2024-12-12 11:47:50 -08:00 committed by GitHub
parent a14785af46
commit 96e158eaac
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 677 additions and 156 deletions

View file

@ -89,6 +89,7 @@ class VectorMemoryBank(MemoryBankResourceMixin):
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str embedding_model: str
chunk_size_in_tokens: int chunk_size_in_tokens: int
embedding_dimension: Optional[int] = 384 # default to minilm-l6-v2
overlap_size_in_tokens: Optional[int] = None overlap_size_in_tokens: Optional[int] = None

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
@ -20,6 +21,11 @@ class CommonModelFields(BaseModel):
) )
class ModelType(Enum):
llm = "llm"
embedding_model = "embedding"
@json_schema_type @json_schema_type
class Model(CommonModelFields, Resource): class Model(CommonModelFields, Resource):
type: Literal[ResourceType.model.value] = ResourceType.model.value type: Literal[ResourceType.model.value] = ResourceType.model.value
@ -34,12 +40,14 @@ class Model(CommonModelFields, Resource):
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
model_type: ModelType = Field(default=ModelType.llm)
class ModelInput(CommonModelFields): class ModelInput(CommonModelFields):
model_id: str model_id: str
provider_id: Optional[str] = None provider_id: Optional[str] = None
provider_model_id: Optional[str] = None provider_model_id: Optional[str] = None
model_type: Optional[ModelType] = ModelType.llm
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
@ -59,6 +67,7 @@ class Models(Protocol):
provider_model_id: Optional[str] = None, provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> Model: ... ) -> Model: ...
@webmethod(route="/models/unregister", method="POST") @webmethod(route="/models/unregister", method="POST")

View file

@ -88,9 +88,10 @@ class InferenceRouter(Inference):
provider_model_id: Optional[str] = None, provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> None: ) -> None:
await self.routing_table.register_model( await self.routing_table.register_model(
model_id, provider_model_id, provider_id, metadata model_id, provider_model_id, provider_id, metadata, model_type
) )
async def chat_completion( async def chat_completion(
@ -105,6 +106,13 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding_model:
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
params = dict( params = dict(
model_id=model_id, model_id=model_id,
messages=messages, messages=messages,
@ -131,6 +139,13 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding_model:
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
provider = self.routing_table.get_provider_impl(model_id) provider = self.routing_table.get_provider_impl(model_id)
params = dict( params = dict(
model_id=model_id, model_id=model_id,
@ -150,6 +165,13 @@ class InferenceRouter(Inference):
model_id: str, model_id: str,
contents: List[InterleavedTextMedia], contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ) -> EmbeddingsResponse:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.llm:
raise ValueError(
f"Model '{model_id}' is an LLM model and does not support embeddings"
)
return await self.routing_table.get_provider_impl(model_id).embeddings( return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id, model_id=model_id,
contents=contents, contents=contents,

View file

@ -209,6 +209,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_model_id: Optional[str] = None, provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> Model: ) -> Model:
if provider_model_id is None: if provider_model_id is None:
provider_model_id = model_id provider_model_id = model_id
@ -222,11 +223,21 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
) )
if metadata is None: if metadata is None:
metadata = {} metadata = {}
if model_type is None:
model_type = ModelType.llm
if (
"embedding_dimension" not in metadata
and model_type == ModelType.embedding_model
):
raise ValueError(
"Embedding model must have an embedding dimension in its metadata"
)
model = Model( model = Model(
identifier=model_id, identifier=model_id,
provider_resource_id=provider_model_id, provider_resource_id=provider_model_id,
provider_id=provider_id, provider_id=provider_id,
metadata=metadata, metadata=metadata,
model_type=model_type,
) )
registered_model = await self.register_object(model) registered_model = await self.register_object(model)
return registered_model return registered_model
@ -298,16 +309,29 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
raise ValueError( raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id." "No provider specified and multiple providers available. Please specify a provider_id."
) )
memory_bank = parse_obj_as( model = await self.get_object_by_identifier("model", params.embedding_model)
MemoryBank, if model is None:
{ raise ValueError(f"Model {params.embedding_model} not found")
if model.model_type != ModelType.embedding_model:
raise ValueError(
f"Model {params.embedding_model} is not an embedding model"
)
if "embedding_dimension" not in model.metadata:
raise ValueError(
f"Model {params.embedding_model} does not have an embedding dimension"
)
memory_bank_data = {
"identifier": memory_bank_id, "identifier": memory_bank_id,
"type": ResourceType.memory_bank.value, "type": ResourceType.memory_bank.value,
"provider_id": provider_id, "provider_id": provider_id,
"provider_resource_id": provider_memory_bank_id, "provider_resource_id": provider_memory_bank_id,
**params.model_dump(), **params.model_dump(),
}, }
) if params.memory_bank_type == MemoryBankType.vector.value:
memory_bank_data["embedding_dimension"] = model.metadata[
"embedding_dimension"
]
memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
await self.register_object(memory_bank) await self.register_object(memory_bank)
return memory_bank return memory_bank

View file

@ -40,7 +40,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry" REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v2" KEY_VERSION = "v3"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

View file

@ -200,10 +200,13 @@ API responses, specify the adapter here.
return self.adapter.provider_data_validator return self.adapter.provider_data_validator
def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec: def remote_provider_spec(
api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None
) -> RemoteProviderSpec:
return RemoteProviderSpec( return RemoteProviderSpec(
api=api, api=api,
provider_type=f"remote::{adapter.adapter_type}", provider_type=f"remote::{adapter.adapter_type}",
config_class=adapter.config_class, config_class=adapter.config_class,
adapter=adapter, adapter=adapter,
api_dependencies=api_dependencies or [],
) )

View file

@ -16,12 +16,14 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import build_model_alias from llama_stack.providers.utils.inference.model_registry import build_model_alias
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_media_to_url, convert_image_media_to_url,
request_has_media, request_has_media,
) )
from .config import MetaReferenceInferenceConfig from .config import MetaReferenceInferenceConfig
from .generation import Llama from .generation import Llama
from .model_parallel import LlamaModelParallelGenerator from .model_parallel import LlamaModelParallelGenerator
@ -32,12 +34,17 @@ log = logging.getLogger(__name__)
SEMAPHORE = asyncio.Semaphore(1) SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate): class MetaReferenceInferenceImpl(
SentenceTransformerEmbeddingMixin,
Inference,
ModelsProtocolPrivate,
):
def __init__(self, config: MetaReferenceInferenceConfig) -> None: def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config self.config = config
model = resolve_model(config.model) model = resolve_model(config.model)
ModelRegistryHelper.__init__( if model is None:
self, raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
self.model_registry_helper = ModelRegistryHelper(
[ [
build_model_alias( build_model_alias(
model.descriptor(), model.descriptor(),
@ -45,8 +52,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
) )
], ],
) )
if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
self.model = model self.model = model
# verify that the checkpoint actually is for this model lol # verify that the checkpoint actually is for this model lol
@ -76,6 +81,12 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
async def unregister_model(self, model_id: str) -> None: async def unregister_model(self, model_id: str) -> None:
pass pass
async def register_model(self, model: Model) -> Model:
model = await self.model_registry_helper.register_model(model)
if model.model_type == ModelType.embedding_model:
self._load_sentence_transformer_model(model.provider_resource_id)
return model
async def completion( async def completion(
self, self,
model_id: str, model_id: str,
@ -394,13 +405,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
for x in impl(): for x in impl():
yield x yield x
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
async def request_with_localized_media( async def request_with_localized_media(
request: Union[ChatCompletionRequest, CompletionRequest], request: Union[ChatCompletionRequest, CompletionRequest],

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.inline.inference.sentence_transformers.config import (
SentenceTransformersInferenceConfig,
)
async def get_provider_impl(
config: SentenceTransformersInferenceConfig,
_deps,
):
from .sentence_transformers import SentenceTransformersInferenceImpl
impl = SentenceTransformersInferenceImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class SentenceTransformersInferenceConfig(BaseModel): ...

View file

@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import AsyncGenerator, List, Optional, Union
from llama_stack.apis.inference import (
CompletionResponse,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from .config import SentenceTransformersInferenceConfig
log = logging.getLogger(__name__)
class SentenceTransformersInferenceImpl(
SentenceTransformerEmbeddingMixin,
Inference,
ModelsProtocolPrivate,
):
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_model(self, model: Model) -> None:
_ = self._load_sentence_transformer_model(model.provider_resource_id)
return model
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model_id: str,
content: str,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncGenerator]:
raise ValueError("Sentence transformers don't support completion")
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")

View file

@ -4,16 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import FaissImplConfig from .config import FaissImplConfig
async def get_provider_impl(config: FaissImplConfig, _deps): async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
from .faiss import FaissMemoryImpl from .faiss import FaissMemoryImpl
assert isinstance( assert isinstance(
config, FaissImplConfig config, FaissImplConfig
), f"Unexpected config type: {type(config)}" ), f"Unexpected config type: {type(config)}"
impl = FaissMemoryImpl(config) impl = FaissMemoryImpl(config, deps[Api.inference])
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -19,11 +19,10 @@ from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
BankWithIndex, BankWithIndex,
EmbeddingIndex, EmbeddingIndex,
) )
@ -32,7 +31,8 @@ from .config import FaissImplConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MEMORY_BANKS_PREFIX = "memory_banks:v1::" MEMORY_BANKS_PREFIX = "memory_banks:v2::"
FAISS_INDEX_PREFIX = "faiss_index:v2::"
class FaissIndex(EmbeddingIndex): class FaissIndex(EmbeddingIndex):
@ -56,7 +56,7 @@ class FaissIndex(EmbeddingIndex):
if not self.kvstore: if not self.kvstore:
return return
index_key = f"faiss_index:v1::{self.bank_id}" index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
stored_data = await self.kvstore.get(index_key) stored_data = await self.kvstore.get(index_key)
if stored_data: if stored_data:
@ -85,16 +85,25 @@ class FaissIndex(EmbeddingIndex):
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"), "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
} }
index_key = f"faiss_index:v1::{self.bank_id}" index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
await self.kvstore.set(key=index_key, value=json.dumps(data)) await self.kvstore.set(key=index_key, value=json.dumps(data))
async def delete(self): async def delete(self):
if not self.kvstore or not self.bank_id: if not self.kvstore or not self.bank_id:
return return
await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}") await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}")
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
# Add dimension check
embedding_dim = (
embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
)
if embedding_dim != self.index.d:
raise ValueError(
f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
)
indexlen = len(self.id_by_index) indexlen = len(self.id_by_index)
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = chunk self.chunk_by_index[indexlen + i] = chunk
@ -124,8 +133,9 @@ class FaissIndex(EmbeddingIndex):
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: FaissImplConfig) -> None: def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
self.config = config self.config = config
self.inference_api = inference_api
self.cache = {} self.cache = {}
self.kvstore = None self.kvstore = None
@ -139,10 +149,11 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
for bank_data in stored_banks: for bank_data in stored_banks:
bank = VectorMemoryBank.model_validate_json(bank_data) bank = VectorMemoryBank.model_validate_json(bank_data)
index = BankWithIndex( index = BankWithIndex(
bank=bank, bank,
index=await FaissIndex.create( await FaissIndex.create(
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier bank.embedding_dimension, self.kvstore, bank.identifier
), ),
self.inference_api,
) )
self.cache[bank.identifier] = index self.cache[bank.identifier] = index
@ -166,13 +177,13 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
) )
# Store in cache # Store in cache
index = BankWithIndex( self.cache[memory_bank.identifier] = BankWithIndex(
bank=memory_bank, memory_bank,
index=await FaissIndex.create( await FaissIndex.create(
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
), ),
self.inference_api,
) )
self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[MemoryBank]: async def list_memory_banks(self) -> List[MemoryBank]:
return [i.bank for i in self.cache.values()] return [i.bank for i in self.cache.values()]

View file

@ -18,6 +18,7 @@ META_REFERENCE_DEPS = [
"transformers", "transformers",
"zmq", "zmq",
"lm-format-enforcer", "lm-format-enforcer",
"sentence-transformers",
] ]
@ -52,6 +53,13 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.inline.inference.vllm", module="llama_stack.providers.inline.inference.vllm",
config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig", config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
), ),
InlineProviderSpec(
api=Api.inference,
provider_type="inline::sentence-transformers",
pip_packages=["sentence-transformers"],
module="llama_stack.providers.inline.inference.sentence_transformers",
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
),
remote_provider_spec( remote_provider_spec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter=AdapterSpec(

View file

@ -39,6 +39,7 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.inline.memory.faiss", module="llama_stack.providers.inline.memory.faiss",
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig", config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
deprecation_warning="Please use the `inline::faiss` provider instead.", deprecation_warning="Please use the `inline::faiss` provider instead.",
api_dependencies=[Api.inference],
), ),
InlineProviderSpec( InlineProviderSpec(
api=Api.memory, api=Api.memory,
@ -46,6 +47,7 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
module="llama_stack.providers.inline.memory.faiss", module="llama_stack.providers.inline.memory.faiss",
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig", config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
api_dependencies=[Api.inference],
), ),
remote_provider_spec( remote_provider_spec(
Api.memory, Api.memory,
@ -55,6 +57,7 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.remote.memory.chroma", module="llama_stack.providers.remote.memory.chroma",
config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig", config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig",
), ),
api_dependencies=[Api.inference],
), ),
InlineProviderSpec( InlineProviderSpec(
api=Api.memory, api=Api.memory,
@ -71,6 +74,7 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.remote.memory.pgvector", module="llama_stack.providers.remote.memory.pgvector",
config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig", config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
), ),
api_dependencies=[Api.inference],
), ),
remote_provider_spec( remote_provider_spec(
Api.memory, Api.memory,
@ -81,6 +85,7 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig", config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData", provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
), ),
api_dependencies=[Api.inference],
), ),
remote_provider_spec( remote_provider_spec(
api=Api.memory, api=Api.memory,
@ -90,6 +95,7 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.remote.memory.sample", module="llama_stack.providers.remote.memory.sample",
config_class="llama_stack.providers.remote.memory.sample.SampleConfig", config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
), ),
api_dependencies=[],
), ),
remote_provider_spec( remote_provider_spec(
Api.memory, Api.memory,
@ -99,5 +105,6 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.remote.memory.qdrant", module="llama_stack.providers.remote.memory.qdrant",
config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig", config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
), ),
api_dependencies=[Api.inference],
), ),
] ]

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from typing import * # noqa: F403 from typing import * # noqa: F403
import json
from botocore.client import BaseClient from botocore.client import BaseClient
from llama_models.datatypes import CoreModelId from llama_models.datatypes import CoreModelId
@ -19,8 +20,10 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
model_aliases = [ model_aliases = [
@ -448,4 +451,21 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
model_id: str, model_id: str,
contents: List[InterleavedTextMedia], contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ) -> EmbeddingsResponse:
raise NotImplementedError() model = await self.model_store.get_model(model_id)
embeddings = []
for content in contents:
assert not content_has_media(
content
), "Bedrock does not support media for embeddings"
input_text = interleaved_text_media_as_str(content)
input_body = {"inputText": input_text}
body = json.dumps(input_body)
response = self.client.invoke_model(
body=body,
modelId=model.provider_resource_id,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
embeddings.append(response_body.get("embedding"))
return EmbeddingsResponse(embeddings=embeddings)

View file

@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
@json_schema_type @json_schema_type
class FireworksImplConfig(BaseModel): class FireworksImplConfig(BaseModel):
url: str = Field( url: str = Field(
default="https://api.fireworks.ai/inference", default="https://api.fireworks.ai/inference/v1",
description="The URL for the Fireworks server", description="The URL for the Fireworks server",
) )
api_key: Optional[str] = Field( api_key: Optional[str] = Field(
@ -24,6 +24,6 @@ class FireworksImplConfig(BaseModel):
@classmethod @classmethod
def sample_run_config(cls) -> Dict[str, Any]: def sample_run_config(cls) -> Dict[str, Any]:
return { return {
"url": "https://api.fireworks.ai/inference", "url": "https://api.fireworks.ai/inference/v1",
"api_key": "${env.FIREWORKS_API_KEY}", "api_key": "${env.FIREWORKS_API_KEY}",
} }

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import AsyncGenerator from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks from fireworks.client import Fireworks
from llama_models.datatypes import CoreModelId from llama_models.datatypes import CoreModelId
@ -28,6 +28,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt, chat_completion_request_to_prompt,
completion_request_to_prompt, completion_request_to_prompt,
content_has_media,
convert_message_to_dict, convert_message_to_dict,
request_has_media, request_has_media,
) )
@ -89,17 +90,19 @@ class FireworksInferenceAdapter(
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
def _get_client(self) -> Fireworks: def _get_api_key(self) -> str:
fireworks_api_key = None
if self.config.api_key is not None: if self.config.api_key is not None:
fireworks_api_key = self.config.api_key return self.config.api_key
else: else:
provider_data = self.get_request_provider_data() provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.fireworks_api_key: if provider_data is None or not provider_data.fireworks_api_key:
raise ValueError( raise ValueError(
'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}' 'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
) )
fireworks_api_key = provider_data.fireworks_api_key return provider_data.fireworks_api_key
def _get_client(self) -> Fireworks:
fireworks_api_key = self._get_api_key()
return Fireworks(api_key=fireworks_api_key) return Fireworks(api_key=fireworks_api_key)
async def completion( async def completion(
@ -264,4 +267,19 @@ class FireworksInferenceAdapter(
model_id: str, model_id: str,
contents: List[InterleavedTextMedia], contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ) -> EmbeddingsResponse:
raise NotImplementedError() model = await self.model_store.get_model(model_id)
kwargs = {}
if model.metadata.get("embedding_dimensions"):
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
assert all(
not content_has_media(content) for content in contents
), "Fireworks does not support media for embeddings"
response = self._get_client().embeddings.create(
model=model.provider_resource_id,
input=[interleaved_text_media_as_str(content) for content in contents],
**kwargs,
)
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)

View file

@ -36,6 +36,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt, chat_completion_request_to_prompt,
completion_request_to_prompt, completion_request_to_prompt,
content_has_media,
convert_image_media_to_url, convert_image_media_to_url,
request_has_media, request_has_media,
) )
@ -321,9 +322,30 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
model_id: str, model_id: str,
contents: List[InterleavedTextMedia], contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ) -> EmbeddingsResponse:
raise NotImplementedError() model = await self.model_store.get_model(model_id)
assert all(
not content_has_media(content) for content in contents
), "Ollama does not support media for embeddings"
response = await self.client.embed(
model=model.provider_resource_id,
input=[interleaved_text_media_as_str(content) for content in contents],
)
embeddings = response["embeddings"]
return EmbeddingsResponse(embeddings=embeddings)
async def register_model(self, model: Model) -> Model: async def register_model(self, model: Model) -> Model:
# ollama does not have embedding models running. Check if the model is in list of available models.
if model.model_type == ModelType.embedding_model:
response = await self.client.list()
available_models = [m["model"] for m in response["models"]]
if model.provider_resource_id not in available_models:
raise ValueError(
f"Model '{model.provider_resource_id}' is not available in Ollama. "
f"Available models: {', '.join(available_models)}"
)
return model
model = await self.register_helper.register_model(model) model = await self.register_helper.register_model(model)
models = await self.client.ps() models = await self.client.ps()
available_models = [m["model"] for m in models["models"]] available_models = [m["model"] for m in models["models"]]

View file

@ -31,6 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt, chat_completion_request_to_prompt,
completion_request_to_prompt, completion_request_to_prompt,
content_has_media,
convert_message_to_dict, convert_message_to_dict,
request_has_media, request_has_media,
) )
@ -253,4 +254,13 @@ class TogetherInferenceAdapter(
model_id: str, model_id: str,
contents: List[InterleavedTextMedia], contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ) -> EmbeddingsResponse:
raise NotImplementedError() model = await self.model_store.get_model(model_id)
assert all(
not content_has_media(content) for content in contents
), "Together does not support media for embeddings"
r = self._get_client().embeddings.create(
model=model.provider_resource_id,
input=[interleaved_text_media_as_str(content) for content in contents],
)
embeddings = [item.embedding for item in r.data]
return EmbeddingsResponse(embeddings=embeddings)

View file

@ -29,6 +29,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt, chat_completion_request_to_prompt,
completion_request_to_prompt, completion_request_to_prompt,
content_has_media,
convert_message_to_dict, convert_message_to_dict,
request_has_media, request_has_media,
) )
@ -203,4 +204,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
model_id: str, model_id: str,
contents: List[InterleavedTextMedia], contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ) -> EmbeddingsResponse:
raise NotImplementedError() model = await self.model_store.get_model(model_id)
kwargs = {}
assert model.model_type == ModelType.embedding_model
assert model.metadata.get("embedding_dimensions")
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
assert all(
not content_has_media(content) for content in contents
), "VLLM does not support media for embeddings"
response = self.client.embeddings.create(
model=model.provider_resource_id,
input=[interleaved_text_media_as_str(content) for content in contents],
**kwargs,
)
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)

View file

@ -4,12 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import ChromaRemoteImplConfig from .config import ChromaRemoteImplConfig
async def get_adapter_impl(config: ChromaRemoteImplConfig, _deps): async def get_adapter_impl(
config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec]
):
from .chroma import ChromaMemoryAdapter from .chroma import ChromaMemoryAdapter
impl = ChromaMemoryAdapter(config) impl = ChromaMemoryAdapter(config, deps[Api.inference])
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -13,8 +13,7 @@ import chromadb
from numpy.typing import NDArray from numpy.typing import NDArray
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex, BankWithIndex,
@ -87,10 +86,14 @@ class ChromaIndex(EmbeddingIndex):
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__( def __init__(
self, config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig] self,
config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig],
inference_api: Api.inference,
) -> None: ) -> None:
log.info(f"Initializing ChromaMemoryAdapter with url: {config}") log.info(f"Initializing ChromaMemoryAdapter with url: {config}")
self.config = config self.config = config
self.inference_api = inference_api
self.client = None self.client = None
self.cache = {} self.cache = {}
@ -127,10 +130,9 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
metadata={"bank": memory_bank.model_dump_json()}, metadata={"bank": memory_bank.model_dump_json()},
) )
) )
bank_index = BankWithIndex( self.cache[memory_bank.identifier] = BankWithIndex(
bank=memory_bank, index=ChromaIndex(self.client, collection) memory_bank, ChromaIndex(self.client, collection), self.inference_api
) )
self.cache[memory_bank.identifier] = bank_index
async def unregister_memory_bank(self, memory_bank_id: str) -> None: async def unregister_memory_bank(self, memory_bank_id: str) -> None:
await self.cache[memory_bank_id].index.delete() await self.cache[memory_bank_id].index.delete()
@ -166,6 +168,8 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
collection = await maybe_await(self.client.get_collection(bank_id)) collection = await maybe_await(self.client.get_collection(bank_id))
if not collection: if not collection:
raise ValueError(f"Bank {bank_id} not found in Chroma") raise ValueError(f"Bank {bank_id} not found in Chroma")
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection)) index = BankWithIndex(
bank, ChromaIndex(self.client, collection), self.inference_api
)
self.cache[bank_id] = index self.cache[bank_id] = index
return index return index

View file

@ -4,12 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import PGVectorConfig from .config import PGVectorConfig
async def get_adapter_impl(config: PGVectorConfig, _deps): async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
from .pgvector import PGVectorMemoryAdapter from .pgvector import PGVectorMemoryAdapter
impl = PGVectorMemoryAdapter(config) impl = PGVectorMemoryAdapter(config, deps[Api.inference])
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -16,9 +16,9 @@ from pydantic import BaseModel, parse_obj_as
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
BankWithIndex, BankWithIndex,
EmbeddingIndex, EmbeddingIndex,
) )
@ -120,8 +120,9 @@ class PGVectorIndex(EmbeddingIndex):
class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: PGVectorConfig) -> None: def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
self.config = config self.config = config
self.inference_api = inference_api
self.cursor = None self.cursor = None
self.conn = None self.conn = None
self.cache = {} self.cache = {}
@ -160,27 +161,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def register_memory_bank( async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
self,
memory_bank: MemoryBank,
) -> None:
assert ( assert (
memory_bank.memory_bank_type == MemoryBankType.vector.value memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.memory_bank_type}" ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
upsert_models( upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
self.cursor, index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
[ self.cache[memory_bank.identifier] = BankWithIndex(
(memory_bank.identifier, memory_bank), memory_bank, index, self.inference_api
],
) )
index = BankWithIndex(
bank=memory_bank,
index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[memory_bank.identifier] = index
async def unregister_memory_bank(self, memory_bank_id: str) -> None: async def unregister_memory_bank(self, memory_bank_id: str) -> None:
await self.cache[memory_bank_id].index.delete() await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id] del self.cache[memory_bank_id]
@ -203,14 +194,13 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
index = await self._get_and_cache_bank_index(bank_id) index = await self._get_and_cache_bank_index(bank_id)
return await index.query_documents(query, params) return await index.query_documents(query, params)
self.inference_api = inference_api
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex: async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
if bank_id in self.cache: if bank_id in self.cache:
return self.cache[bank_id] return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id) bank = await self.memory_bank_store.get_memory_bank(bank_id)
index = BankWithIndex( index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
bank=bank, self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), return self.cache[bank_id]
)
self.cache[bank_id] = index
return index

View file

@ -4,12 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import QdrantConfig from .config import QdrantConfig
async def get_adapter_impl(config: QdrantConfig, _deps): async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]):
from .qdrant import QdrantVectorMemoryAdapter from .qdrant import QdrantVectorMemoryAdapter
impl = QdrantVectorMemoryAdapter(config) impl = QdrantVectorMemoryAdapter(config, deps[Api.inference])
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -101,10 +101,11 @@ class QdrantIndex(EmbeddingIndex):
class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: QdrantConfig) -> None: def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
self.config = config self.config = config
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True)) self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
self.cache = {} self.cache = {}
self.inference_api = inference_api
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
@ -123,6 +124,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
index = BankWithIndex( index = BankWithIndex(
bank=memory_bank, bank=memory_bank,
index=QdrantIndex(self.client, memory_bank.identifier), index=QdrantIndex(self.client, memory_bank.identifier),
inference_api=self.inference_api,
) )
self.cache[memory_bank.identifier] = index self.cache[memory_bank.identifier] = index
@ -138,6 +140,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
index = BankWithIndex( index = BankWithIndex(
bank=bank, bank=bank,
index=QdrantIndex(client=self.client, collection_name=bank_id), index=QdrantIndex(client=self.client, collection_name=bank_id),
inference_api=self.inference_api,
) )
self.cache[bank_id] = index self.cache[bank_id] = index
return index return index

View file

@ -4,12 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401 from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401
async def get_adapter_impl(config: WeaviateConfig, _deps): async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]):
from .weaviate import WeaviateMemoryAdapter from .weaviate import WeaviateMemoryAdapter
impl = WeaviateMemoryAdapter(config) impl = WeaviateMemoryAdapter(config, deps[Api.inference])
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -12,10 +12,11 @@ import weaviate
import weaviate.classes as wvc import weaviate.classes as wvc
from numpy.typing import NDArray from numpy.typing import NDArray
from weaviate.classes.init import Auth from weaviate.classes.init import Auth
from weaviate.classes.query import Filter
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex, BankWithIndex,
EmbeddingIndex, EmbeddingIndex,
@ -80,12 +81,21 @@ class WeaviateIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores) return QueryDocumentsResponse(chunks=chunks, scores=scores)
async def delete(self, chunk_ids: List[str]) -> None:
collection = self.client.collections.get(self.collection_name)
collection.data.delete_many(
where=Filter.by_property("id").contains_any(chunk_ids)
)
class WeaviateMemoryAdapter( class WeaviateMemoryAdapter(
Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate Memory,
NeedsRequestProviderData,
MemoryBanksProtocolPrivate,
): ):
def __init__(self, config: WeaviateConfig) -> None: def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
self.config = config self.config = config
self.inference_api = inference_api
self.client_cache = {} self.client_cache = {}
self.cache = {} self.cache = {}
@ -117,7 +127,7 @@ class WeaviateMemoryAdapter(
memory_bank: MemoryBank, memory_bank: MemoryBank,
) -> None: ) -> None:
assert ( assert (
memory_bank.memory_bank_type == MemoryBankType.vector memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.memory_bank_type}" ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
client = self._get_client() client = self._get_client()
@ -135,11 +145,11 @@ class WeaviateMemoryAdapter(
], ],
) )
index = BankWithIndex( self.cache[memory_bank.identifier] = BankWithIndex(
bank=memory_bank, memory_bank,
index=WeaviateIndex(client=client, collection_name=memory_bank.identifier), WeaviateIndex(client=client, collection_name=memory_bank.identifier),
self.inference_api,
) )
self.cache[memory_bank.identifier] = index
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache: if bank_id in self.cache:
@ -156,6 +166,7 @@ class WeaviateMemoryAdapter(
index = BankWithIndex( index = BankWithIndex(
bank=bank, bank=bank,
index=WeaviateIndex(client=client, collection_name=bank_id), index=WeaviateIndex(client=client, collection_name=bank_id),
inference_api=self.inference_api,
) )
self.cache[bank_id] = index self.cache[bank_id] = index
return index return index

View file

@ -18,6 +18,12 @@ def pytest_addoption(parser):
default=None, default=None,
help="Specify the inference model to use for testing", help="Specify the inference model to use for testing",
) )
parser.addoption(
"--embedding-model",
action="store",
default=None,
help="Specify the embedding model to use for testing",
)
def pytest_configure(config): def pytest_configure(config):

View file

@ -9,9 +9,9 @@ import os
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from llama_stack.apis.models import ModelInput from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.inference.meta_reference import ( from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceInferenceConfig, MetaReferenceInferenceConfig,
) )
@ -47,6 +47,9 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
inference_model = ( inference_model = (
[inference_model] if isinstance(inference_model, str) else inference_model [inference_model] if isinstance(inference_model, str) else inference_model
) )
# If embedding dimension is set, use the 8B model for testing
if os.getenv("EMBEDDING_DIMENSION"):
inference_model = ["meta-llama/Llama-3.1-8B-Instruct"]
return ProviderFixture( return ProviderFixture(
providers=[ providers=[
@ -85,7 +88,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
inference_model = ( inference_model = (
[inference_model] if isinstance(inference_model, str) else inference_model [inference_model] if isinstance(inference_model, str) else inference_model
) )
if "Llama3.1-8B-Instruct" in inference_model: if inference_model and "Llama3.1-8B-Instruct" in inference_model:
pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing") pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
return ProviderFixture( return ProviderFixture(
@ -232,11 +235,23 @@ INFERENCE_FIXTURES = [
async def inference_stack(request, inference_model): async def inference_stack(request, inference_model):
fixture_name = request.param fixture_name = request.param
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}") inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
model_type = ModelType.llm
metadata = {}
if os.getenv("EMBEDDING_DIMENSION"):
model_type = ModelType.embedding_model
metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")
test_stack = await construct_stack_for_test( test_stack = await construct_stack_for_test(
[Api.inference], [Api.inference],
{"inference": inference_fixture.providers}, {"inference": inference_fixture.providers},
inference_fixture.provider_data, inference_fixture.provider_data,
models=[ModelInput(model_id=inference_model)], models=[
ModelInput(
model_id=inference_model,
model_type=model_type,
metadata=metadata,
)
],
) )
return test_stack.impls[Api.inference], test_stack.impls[Api.models] return test_stack.impls[Api.inference], test_stack.impls[Api.models]

View file

@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.inference import EmbeddingsResponse, ModelType
# How to run this test:
# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py
class TestEmbeddings:
@pytest.mark.asyncio
async def test_embeddings(self, inference_model, inference_stack):
inference_impl, models_impl = inference_stack
model = await models_impl.get_model(inference_model)
if model.model_type != ModelType.embedding_model:
pytest.skip("This test is only applicable for embedding models")
response = await inference_impl.embeddings(
model_id=inference_model,
contents=["Hello, world!"],
)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) > 0
assert all(isinstance(embedding, list) for embedding in response.embeddings)
assert all(
isinstance(value, float)
for embedding in response.embeddings
for value in embedding
)
@pytest.mark.asyncio
async def test_batch_embeddings(self, inference_model, inference_stack):
inference_impl, models_impl = inference_stack
model = await models_impl.get_model(inference_model)
if model.model_type != ModelType.embedding_model:
pytest.skip("This test is only applicable for embedding models")
texts = ["Hello, world!", "This is a test", "Testing embeddings"]
response = await inference_impl.embeddings(
model_id=inference_model,
contents=texts,
)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == len(texts)
assert all(isinstance(embedding, list) for embedding in response.embeddings)
assert all(
isinstance(value, float)
for embedding in response.embeddings
for value in embedding
)
embedding_dim = len(response.embeddings[0])
assert all(len(embedding) == embedding_dim for embedding in response.embeddings)

View file

@ -6,9 +6,65 @@
import pytest import pytest
from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES
from .fixtures import MEMORY_FIXTURES from .fixtures import MEMORY_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"inference": "meta_reference",
"memory": "faiss",
},
id="meta_reference",
marks=pytest.mark.meta_reference,
),
pytest.param(
{
"inference": "ollama",
"memory": "pgvector",
},
id="ollama",
marks=pytest.mark.ollama,
),
pytest.param(
{
"inference": "together",
"memory": "chroma",
},
id="chroma",
marks=pytest.mark.chroma,
),
pytest.param(
{
"inference": "bedrock",
"memory": "qdrant",
},
id="qdrant",
marks=pytest.mark.qdrant,
),
pytest.param(
{
"inference": "fireworks",
"memory": "weaviate",
},
id="weaviate",
marks=pytest.mark.weaviate,
),
]
def pytest_addoption(parser):
parser.addoption(
"--inference-model",
action="store",
default=None,
help="Specify the inference model to use for testing",
)
def pytest_configure(config): def pytest_configure(config):
for fixture_name in MEMORY_FIXTURES: for fixture_name in MEMORY_FIXTURES:
config.addinivalue_line( config.addinivalue_line(
@ -18,12 +74,22 @@ def pytest_configure(config):
def pytest_generate_tests(metafunc): def pytest_generate_tests(metafunc):
if "memory_stack" in metafunc.fixturenames: if "inference_model" in metafunc.fixturenames:
metafunc.parametrize( model = metafunc.config.getoption("--inference-model")
"memory_stack", if not model:
[ raise ValueError(
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name)) "No inference model specified. Please provide a valid inference model."
for fixture_name in MEMORY_FIXTURES
],
indirect=True,
) )
params = [pytest.param(model, id="")]
metafunc.parametrize("inference_model", params, indirect=True)
if "memory_stack" in metafunc.fixturenames:
available_fixtures = {
"inference": INFERENCE_FIXTURES,
"memory": MEMORY_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("memory_stack", combinations, indirect=True)

View file

@ -10,6 +10,8 @@ import tempfile
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from llama_stack.apis.inference import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
from llama_stack.providers.inline.memory.faiss import FaissImplConfig from llama_stack.providers.inline.memory.faiss import FaissImplConfig
@ -105,14 +107,30 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")
async def memory_stack(request): async def memory_stack(inference_model, request):
fixture_name = request.param fixture_dict = request.param
fixture = request.getfixturevalue(f"memory_{fixture_name}")
providers = {}
provider_data = {}
for key in ["inference", "memory"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
test_stack = await construct_stack_for_test( test_stack = await construct_stack_for_test(
[Api.memory], [Api.memory, Api.inference],
{"memory": fixture.providers}, providers,
fixture.provider_data, provider_data,
models=[
ModelInput(
model_id=inference_model,
model_type=ModelType.embedding_model,
metadata={
"embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
},
)
],
) )
return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks] return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]

View file

@ -45,12 +45,14 @@ def sample_documents():
] ]
async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank: async def register_memory_bank(
banks_impl: MemoryBanks, inference_model: str
) -> MemoryBank:
bank_id = f"test_bank_{uuid.uuid4().hex}" bank_id = f"test_bank_{uuid.uuid4().hex}"
return await banks_impl.register_memory_bank( return await banks_impl.register_memory_bank(
memory_bank_id=bank_id, memory_bank_id=bank_id,
params=VectorMemoryBankParams( params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2", embedding_model=inference_model,
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
overlap_size_in_tokens=64, overlap_size_in_tokens=64,
), ),
@ -59,11 +61,11 @@ async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
class TestMemory: class TestMemory:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_banks_list(self, memory_stack): async def test_banks_list(self, memory_stack, inference_model):
_, banks_impl = memory_stack _, banks_impl = memory_stack
# Register a test bank # Register a test bank
registered_bank = await register_memory_bank(banks_impl) registered_bank = await register_memory_bank(banks_impl, inference_model)
try: try:
# Verify our bank shows up in list # Verify our bank shows up in list
@ -84,7 +86,7 @@ class TestMemory:
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_banks_register(self, memory_stack): async def test_banks_register(self, memory_stack, inference_model):
_, banks_impl = memory_stack _, banks_impl = memory_stack
bank_id = f"test_bank_{uuid.uuid4().hex}" bank_id = f"test_bank_{uuid.uuid4().hex}"
@ -94,7 +96,7 @@ class TestMemory:
await banks_impl.register_memory_bank( await banks_impl.register_memory_bank(
memory_bank_id=bank_id, memory_bank_id=bank_id,
params=VectorMemoryBankParams( params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2", embedding_model=inference_model,
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
overlap_size_in_tokens=64, overlap_size_in_tokens=64,
), ),
@ -109,7 +111,7 @@ class TestMemory:
await banks_impl.register_memory_bank( await banks_impl.register_memory_bank(
memory_bank_id=bank_id, memory_bank_id=bank_id,
params=VectorMemoryBankParams( params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2", embedding_model=inference_model,
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
overlap_size_in_tokens=64, overlap_size_in_tokens=64,
), ),
@ -126,13 +128,15 @@ class TestMemory:
await banks_impl.unregister_memory_bank(bank_id) await banks_impl.unregister_memory_bank(bank_id)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_documents(self, memory_stack, sample_documents): async def test_query_documents(
self, memory_stack, inference_model, sample_documents
):
memory_impl, banks_impl = memory_stack memory_impl, banks_impl = memory_stack
with pytest.raises(ValueError): with pytest.raises(ValueError):
await memory_impl.insert_documents("test_bank", sample_documents) await memory_impl.insert_documents("test_bank", sample_documents)
registered_bank = await register_memory_bank(banks_impl) registered_bank = await register_memory_bank(banks_impl, inference_model)
await memory_impl.insert_documents( await memory_impl.insert_documents(
registered_bank.memory_bank_id, sample_documents registered_bank.memory_bank_id, sample_documents
) )
@ -165,13 +169,13 @@ class TestMemory:
# Test case 5: Query with threshold on similarity score # Test case 5: Query with threshold on similarity score
query5 = "quantum computing" # Not directly related to any document query5 = "quantum computing" # Not directly related to any document
params5 = {"score_threshold": 0.2} params5 = {"score_threshold": 0.01}
response5 = await memory_impl.query_documents( response5 = await memory_impl.query_documents(
registered_bank.memory_bank_id, query5, params5 registered_bank.memory_bank_id, query5, params5
) )
assert_valid_response(response5) assert_valid_response(response5)
print("The scores are:", response5.scores) print("The scores are:", response5.scores)
assert all(score >= 0.2 for score in response5.scores) assert all(score >= 0.01 for score in response5.scores)
def assert_valid_response(response: QueryDocumentsResponse): def assert_valid_response(response: QueryDocumentsResponse):

View file

@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import List
from llama_models.llama3.api.datatypes import InterleavedTextMedia
from llama_stack.apis.inference.inference import EmbeddingsResponse, ModelStore
EMBEDDING_MODELS = {}
log = logging.getLogger(__name__)
class SentenceTransformerEmbeddingMixin:
model_store: ModelStore
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embedding_model = self._load_sentence_transformer_model(
model.provider_resource_id
)
embeddings = embedding_model.encode(contents)
return EmbeddingsResponse(embeddings=embeddings)
def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
global EMBEDDING_MODELS
loaded_model = EMBEDDING_MODELS.get(model)
if loaded_model is not None:
return loaded_model
log.info(f"Loading sentence transformer for {model}...")
from sentence_transformers import SentenceTransformer
loaded_model = SentenceTransformer(model)
EMBEDDING_MODELS[model] = loaded_model
return loaded_model

View file

@ -9,6 +9,7 @@ from typing import List, Optional
from llama_models.sku_list import all_registered_models from llama_models.sku_list import all_registered_models
from llama_stack.apis.models.models import ModelType
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference import ( from llama_stack.providers.utils.inference import (
@ -77,7 +78,13 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
return None return None
async def register_model(self, model: Model) -> Model: async def register_model(self, model: Model) -> Model:
provider_resource_id = self.get_provider_model_id(model.provider_resource_id) if model.model_type == ModelType.embedding_model:
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
else:
provider_resource_id = self.get_provider_model_id(
model.provider_resource_id
)
if provider_resource_id: if provider_resource_id:
model.provider_resource_id = provider_resource_id model.provider_resource_id = provider_resource_id
else: else:

View file

@ -22,28 +22,10 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
ALL_MINILM_L6_V2_DIMENSION = 384
EMBEDDING_MODELS = {}
def get_embedding_model(model: str) -> "SentenceTransformer":
global EMBEDDING_MODELS
loaded_model = EMBEDDING_MODELS.get(model)
if loaded_model is not None:
return loaded_model
log.info(f"Loading sentence transformer for {model}...")
from sentence_transformers import SentenceTransformer
loaded_model = SentenceTransformer(model)
EMBEDDING_MODELS[model] = loaded_model
return loaded_model
def parse_pdf(data: bytes) -> str: def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string # For PDF and DOC/DOCX files, we can't reliably convert to string
@ -166,12 +148,12 @@ class EmbeddingIndex(ABC):
class BankWithIndex: class BankWithIndex:
bank: VectorMemoryBank bank: VectorMemoryBank
index: EmbeddingIndex index: EmbeddingIndex
inference_api: Api.inference
async def insert_documents( async def insert_documents(
self, self,
documents: List[MemoryBankDocument], documents: List[MemoryBankDocument],
) -> None: ) -> None:
model = get_embedding_model(self.bank.embedding_model)
for doc in documents: for doc in documents:
content = await content_from_doc(doc) content = await content_from_doc(doc)
chunks = make_overlapped_chunks( chunks = make_overlapped_chunks(
@ -183,7 +165,10 @@ class BankWithIndex:
) )
if not chunks: if not chunks:
continue continue
embeddings = model.encode([x.content for x in chunks]).astype(np.float32) embeddings_response = await self.inference_api.embeddings(
self.bank.embedding_model, [x.content for x in chunks]
)
embeddings = np.array(embeddings_response.embeddings)
await self.index.add_chunks(chunks, embeddings) await self.index.add_chunks(chunks, embeddings)
@ -208,6 +193,8 @@ class BankWithIndex:
else: else:
query_str = _process(query) query_str = _process(query)
model = get_embedding_model(self.bank.embedding_model) embeddings_response = await self.inference_api.embeddings(
query_vector = model.encode([query_str])[0].astype(np.float32) self.bank.embedding_model, [query_str]
)
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
return await self.index.query(query_vector, k, score_threshold) return await self.index.query(query_vector, k, score_threshold)