Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 01:03:59 +00:00)
implement embedding generation in supported inference providers (#589)
This PR adds the ability to generate embeddings in all supported inference providers.

```
pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py -k "bedrock" --inference-model="amazon.titan-embed-text-v2:0" --env EMBEDDING_DIMENSION=1024

pytest -v -s -k "vllm" --inference-model="intfloat/e5-mistral-7b-instruct" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=4096 --env VLLM_URL="http://localhost:9798/v1"

pytest -v -s --inference-model="nomic-ai/nomic-embed-text-v1.5" llama_stack/providers/tests/inference/test_embeddings.py -k "fireworks" --env FIREWORKS_API_KEY=<API_KEY> --env EMBEDDING_DIMENSION=128

pytest -v -s --inference-model="togethercomputer/m2-bert-80M-2k-retrieval" llama_stack/providers/tests/inference/test_embeddings.py -k "together" --env TOGETHER_API_KEY=<API_KEY> --env EMBEDDING_DIMENSION=768

pytest -v -s -k "ollama" --inference-model="all-minilm:v8" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384

torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="sentence-transformers/all-MiniLM-L6-v2" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384
```
This commit is contained in: parent 6a23f24ee0, commit d362d2d740
32 changed files with 597 additions and 143 deletions
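Every provider touched by this commit ends up exposing the same `embeddings(model_id, contents)` method returning an `EmbeddingsResponse`. A minimal sketch of exercising it the way the new tests do; the model id is illustrative and the `inference_impl`/`models_impl` handles are assumed to come from an already constructed stack (see the `inference_stack` fixture further down), not something this commit provides directly:

```python
# Hedged sketch: calling the embeddings API added in this commit.
# Assumes a stack with a registered embedding model; the model id is illustrative.
from llama_stack.apis.inference import EmbeddingsResponse


async def embed_example(inference_impl, models_impl) -> None:
    model_id = "sentence-transformers/all-MiniLM-L6-v2"  # assumed to be registered
    model = await models_impl.get_model(model_id)
    print(model.model_type, model.metadata.get("embedding_dimension"))

    response = await inference_impl.embeddings(
        model_id=model_id,
        contents=["Hello, world!", "Embeddings from llama-stack"],
    )
    assert isinstance(response, EmbeddingsResponse)
    # One vector per input; its length should match the registered embedding dimension.
    print(len(response.embeddings), len(response.embeddings[0]))
```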
@@ -48,7 +48,6 @@ class ModelInput(CommonModelFields):
     provider_id: Optional[str] = None
     provider_model_id: Optional[str] = None
     model_type: Optional[ModelType] = ModelType.llm
 
     model_config = ConfigDict(protected_namespaces=())
-
 
@@ -200,10 +200,13 @@ API responses, specify the adapter here.
         return self.adapter.provider_data_validator
 
 
-def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
+def remote_provider_spec(
+    api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None
+) -> RemoteProviderSpec:
     return RemoteProviderSpec(
         api=api,
         provider_type=f"remote::{adapter.adapter_type}",
         config_class=adapter.config_class,
         adapter=adapter,
+        api_dependencies=api_dependencies or [],
     )
@@ -16,12 +16,14 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_stack.providers.utils.inference.model_registry import build_model_alias
 from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.utils.inference.embedding_mixin import (
+    SentenceTransformerEmbeddingMixin,
+)
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_media_to_url,
     request_has_media,
 )
 
 from .config import MetaReferenceInferenceConfig
 from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator
@@ -32,12 +34,17 @@ log = logging.getLogger(__name__)
 SEMAPHORE = asyncio.Semaphore(1)
 
 
-class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
+class MetaReferenceInferenceImpl(
+    SentenceTransformerEmbeddingMixin,
+    Inference,
+    ModelsProtocolPrivate,
+):
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
-        ModelRegistryHelper.__init__(
-            self,
+        if model is None:
+            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
+        self.model_registry_helper = ModelRegistryHelper(
             [
                 build_model_alias(
                     model.descriptor(),
@@ -45,8 +52,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
                 )
             ],
         )
-        if model is None:
-            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
         self.model = model
         # verify that the checkpoint actually is for this model lol
 
@@ -76,6 +81,12 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
     async def unregister_model(self, model_id: str) -> None:
         pass
 
+    async def register_model(self, model: Model) -> Model:
+        model = await self.model_registry_helper.register_model(model)
+        if model.model_type == ModelType.embedding_model:
+            self._load_sentence_transformer_model(model.provider_resource_id)
+        return model
+
     async def completion(
         self,
         model_id: str,
@@ -394,13 +405,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
         for x in impl():
             yield x
 
-    async def embeddings(
-        self,
-        model_id: str,
-        contents: List[InterleavedTextMedia],
-    ) -> EmbeddingsResponse:
-        raise NotImplementedError()
-
 
 async def request_with_localized_media(
     request: Union[ChatCompletionRequest, CompletionRequest],
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.providers.inline.inference.sentence_transformers.config import (
+    SentenceTransformersInferenceConfig,
+)
+
+
+async def get_provider_impl(
+    config: SentenceTransformersInferenceConfig,
+    _deps,
+):
+    from .sentence_transformers import SentenceTransformersInferenceImpl
+
+    impl = SentenceTransformersInferenceImpl(config)
+    await impl.initialize()
+    return impl
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+
+class SentenceTransformersInferenceConfig(BaseModel): ...
@@ -0,0 +1,74 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import logging
+from typing import AsyncGenerator, List, Optional, Union
+
+from llama_stack.apis.inference import (
+    CompletionResponse,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    SamplingParams,
+    ToolChoice,
+    ToolDefinition,
+    ToolPromptFormat,
+)
+from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
+from llama_stack.providers.utils.inference.embedding_mixin import (
+    SentenceTransformerEmbeddingMixin,
+)
+from .config import SentenceTransformersInferenceConfig
+
+log = logging.getLogger(__name__)
+
+
+class SentenceTransformersInferenceImpl(
+    SentenceTransformerEmbeddingMixin,
+    Inference,
+    ModelsProtocolPrivate,
+):
+    def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
+        self.config = config
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def register_model(self, model: Model) -> None:
+        _ = self._load_sentence_transformer_model(model.provider_resource_id)
+        return model
+
+    async def unregister_model(self, model_id: str) -> None:
+        pass
+
+    async def completion(
+        self,
+        model_id: str,
+        content: str,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> Union[CompletionResponse, AsyncGenerator]:
+        raise ValueError("Sentence transformers don't support completion")
+
+    async def chat_completion(
+        self,
+        model_id: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        raise ValueError("Sentence transformers don't support chat completion")
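Both the meta-reference provider above and the new inline sentence-transformers provider lean on `SentenceTransformerEmbeddingMixin`, whose source is not part of this excerpt. A rough sketch of what such a mixin might look like, assuming it lazily loads and caches `sentence-transformers` encoders; only `_load_sentence_transformer_model` and the `embeddings(model_id, contents)` signature are taken from the diff, everything else here is a guess:

```python
# Hedged sketch of SentenceTransformerEmbeddingMixin; not the PR's actual implementation.
import logging
from typing import Any, List

log = logging.getLogger(__name__)

EMBEDDING_MODELS: dict = {}  # assumed process-wide cache of loaded encoders


class SentenceTransformerEmbeddingMixin:
    model_store: Any  # assumed to be wired up by the stack when models are registered

    async def embeddings(self, model_id: str, contents: List[Any]) -> Any:
        # EmbeddingsResponse is the response type used throughout this diff.
        from llama_stack.apis.inference import EmbeddingsResponse

        model = await self.model_store.get_model(model_id)
        encoder = self._load_sentence_transformer_model(model.provider_resource_id)
        embeddings = encoder.encode([str(content) for content in contents]).tolist()
        return EmbeddingsResponse(embeddings=embeddings)

    def _load_sentence_transformer_model(self, model: str):
        # Load lazily and cache so register_model() and embeddings() reuse the encoder.
        loaded_model = EMBEDDING_MODELS.get(model)
        if loaded_model is None:
            log.info(f"Loading sentence-transformers model {model} ...")
            from sentence_transformers import SentenceTransformer

            loaded_model = SentenceTransformer(model)
            EMBEDDING_MODELS[model] = loaded_model
        return loaded_model
```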
@@ -4,16 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
 from .config import FaissImplConfig
 
 
-async def get_provider_impl(config: FaissImplConfig, _deps):
+async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
     from .faiss import FaissMemoryImpl
 
     assert isinstance(
         config, FaissImplConfig
     ), f"Unexpected config type: {type(config)}"
 
-    impl = FaissMemoryImpl(config)
+    impl = FaissMemoryImpl(config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -19,11 +19,10 @@ from numpy.typing import NDArray
 from llama_models.llama3.api.datatypes import * # noqa: F403
 
 from llama_stack.apis.memory import * # noqa: F403
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 
 from llama_stack.providers.utils.memory.vector_store import (
-    ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
     EmbeddingIndex,
 )
@@ -32,7 +31,8 @@ from .config import FaissImplConfig
 
 logger = logging.getLogger(__name__)
 
-MEMORY_BANKS_PREFIX = "memory_banks:v1::"
+MEMORY_BANKS_PREFIX = "memory_banks:v2::"
+FAISS_INDEX_PREFIX = "faiss_index:v2::"
 
 
 class FaissIndex(EmbeddingIndex):
@@ -56,7 +56,7 @@ class FaissIndex(EmbeddingIndex):
         if not self.kvstore:
             return
 
-        index_key = f"faiss_index:v1::{self.bank_id}"
+        index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
         stored_data = await self.kvstore.get(index_key)
 
         if stored_data:
@@ -85,16 +85,25 @@ class FaissIndex(EmbeddingIndex):
             "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
         }
 
-        index_key = f"faiss_index:v1::{self.bank_id}"
+        index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
         await self.kvstore.set(key=index_key, value=json.dumps(data))
 
     async def delete(self):
         if not self.kvstore or not self.bank_id:
             return
 
-        await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
+        await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}")
 
     async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+        # Add dimension check
+        embedding_dim = (
+            embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
+        )
+        if embedding_dim != self.index.d:
+            raise ValueError(
+                f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
+            )
+
         indexlen = len(self.id_by_index)
         for i, chunk in enumerate(chunks):
             self.chunk_by_index[indexlen + i] = chunk
@@ -124,8 +133,9 @@ class FaissIndex(EmbeddingIndex):
 
 
 class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: FaissImplConfig) -> None:
+    def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.cache = {}
         self.kvstore = None
 
@@ -139,10 +149,11 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
         for bank_data in stored_banks:
             bank = VectorMemoryBank.model_validate_json(bank_data)
             index = BankWithIndex(
-                bank=bank,
-                index=await FaissIndex.create(
-                    ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier
+                bank,
+                await FaissIndex.create(
+                    bank.embedding_dimension, self.kvstore, bank.identifier
                 ),
+                self.inference_api,
             )
             self.cache[bank.identifier] = index
 
@@ -166,13 +177,13 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
         )
 
         # Store in cache
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=await FaissIndex.create(
-                ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank,
+            await FaissIndex.create(
+                memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
            ),
+            self.inference_api,
         )
-        self.cache[memory_bank.identifier] = index
 
     async def list_memory_banks(self) -> List[MemoryBank]:
         return [i.bank for i in self.cache.values()]
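The memory providers below now declare the inference API as a dependency and pass it into `BankWithIndex`. The `vector_store` utility itself is not part of this excerpt; a rough sketch of how it presumably uses the injected handle, where field names such as `embedding_model` and the helper method are assumptions rather than code from this commit:

```python
# Hedged sketch of the BankWithIndex side of this refactor (not shown in the diff):
# embeddings come from the bank's configured embedding model via the Inference API
# instead of a hard-coded all-MiniLM-L6-v2 encoder.
from dataclasses import dataclass
from typing import Any, List

import numpy as np


@dataclass
class BankWithIndex:
    bank: Any            # a VectorMemoryBank carrying embedding_dimension (per this diff)
    index: Any           # an EmbeddingIndex implementation such as FaissIndex above
    inference_api: Any   # injected via api_dependencies=[Api.inference]

    async def add_chunks_with_embeddings(self, chunks: List[Any]) -> None:
        # The index dimension must equal bank.embedding_dimension; FaissIndex.add_chunks
        # above now raises on a mismatch.
        response = await self.inference_api.embeddings(
            model_id=self.bank.embedding_model,  # assumed field name
            contents=[chunk.content for chunk in chunks],
        )
        embeddings = np.array(response.embeddings, dtype=np.float32)
        await self.index.add_chunks(chunks, embeddings)
```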
@@ -18,6 +18,7 @@ META_REFERENCE_DEPS = [
     "transformers",
     "zmq",
     "lm-format-enforcer",
+    "sentence-transformers",
 ]
 
 
@@ -52,6 +53,13 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.inline.inference.vllm",
             config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
         ),
+        InlineProviderSpec(
+            api=Api.inference,
+            provider_type="inline::sentence-transformers",
+            pip_packages=["sentence-transformers"],
+            module="llama_stack.providers.inline.inference.sentence_transformers",
+            config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
+        ),
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
@@ -39,6 +39,7 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.inline.memory.faiss",
             config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
             deprecation_warning="Please use the `inline::faiss` provider instead.",
+            api_dependencies=[Api.inference],
         ),
         InlineProviderSpec(
             api=Api.memory,
@@ -46,6 +47,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
             module="llama_stack.providers.inline.memory.faiss",
             config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             Api.memory,
@@ -55,6 +57,7 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.chroma",
                 config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig",
             ),
+            api_dependencies=[Api.inference],
         ),
         InlineProviderSpec(
             api=Api.memory,
@@ -71,6 +74,7 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.pgvector",
                 config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
             ),
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             Api.memory,
@@ -81,6 +85,7 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
                 provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
             ),
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             api=Api.memory,
@@ -90,6 +95,7 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.sample",
                 config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
             ),
+            api_dependencies=[],
         ),
         remote_provider_spec(
             Api.memory,
@@ -99,5 +105,6 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.qdrant",
                 config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
             ),
+            api_dependencies=[Api.inference],
         ),
     ]
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 from typing import * # noqa: F403
+import json
 
 from botocore.client import BaseClient
 from llama_models.datatypes import CoreModelId
@@ -19,8 +20,10 @@ from llama_stack.providers.utils.inference.model_registry import (
 
 from llama_stack.apis.inference import * # noqa: F403
 
 
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
+from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
+
 
 model_aliases = [
@@ -448,4 +451,21 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+        embeddings = []
+        for content in contents:
+            assert not content_has_media(
+                content
+            ), "Bedrock does not support media for embeddings"
+            input_text = interleaved_text_media_as_str(content)
+            input_body = {"inputText": input_text}
+            body = json.dumps(input_body)
+            response = self.client.invoke_model(
+                body=body,
+                modelId=model.provider_resource_id,
+                accept="application/json",
+                contentType="application/json",
+            )
+            response_body = json.loads(response.get("body").read())
+            embeddings.append(response_body.get("embedding"))
+        return EmbeddingsResponse(embeddings=embeddings)
@@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
 @json_schema_type
 class FireworksImplConfig(BaseModel):
     url: str = Field(
-        default="https://api.fireworks.ai/inference",
+        default="https://api.fireworks.ai/inference/v1",
         description="The URL for the Fireworks server",
     )
     api_key: Optional[str] = Field(
@@ -24,6 +24,6 @@ class FireworksImplConfig(BaseModel):
     @classmethod
     def sample_run_config(cls) -> Dict[str, Any]:
         return {
-            "url": "https://api.fireworks.ai/inference",
+            "url": "https://api.fireworks.ai/inference/v1",
             "api_key": "${env.FIREWORKS_API_KEY}",
         }
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional, Union
 
 from fireworks.client import Fireworks
 from llama_models.datatypes import CoreModelId
@@ -28,6 +28,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -89,17 +90,19 @@ class FireworksInferenceAdapter(
     async def shutdown(self) -> None:
         pass
 
-    def _get_client(self) -> Fireworks:
-        fireworks_api_key = None
+    def _get_api_key(self) -> str:
         if self.config.api_key is not None:
-            fireworks_api_key = self.config.api_key
+            return self.config.api_key
         else:
             provider_data = self.get_request_provider_data()
             if provider_data is None or not provider_data.fireworks_api_key:
                 raise ValueError(
                     'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
                 )
-            fireworks_api_key = provider_data.fireworks_api_key
+            return provider_data.fireworks_api_key
+
+    def _get_client(self) -> Fireworks:
+        fireworks_api_key = self._get_api_key()
         return Fireworks(api_key=fireworks_api_key)
 
     async def completion(
@@ -264,4 +267,19 @@ class FireworksInferenceAdapter(
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+
+        kwargs = {}
+        if model.metadata.get("embedding_dimensions"):
+            kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Fireworks does not support media for embeddings"
+        response = self._get_client().embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
+        )
+
+        embeddings = [data.embedding for data in response.data]
+        return EmbeddingsResponse(embeddings=embeddings)
@@ -36,6 +36,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_image_media_to_url,
     request_has_media,
 )
@@ -321,9 +322,30 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Ollama does not support media for embeddings"
+        response = await self.client.embed(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+        )
+        embeddings = response["embeddings"]
+
+        return EmbeddingsResponse(embeddings=embeddings)
 
     async def register_model(self, model: Model) -> Model:
+        # ollama does not have embedding models running. Check if the model is in list of available models.
+        if model.model_type == ModelType.embedding_model:
+            response = await self.client.list()
+            available_models = [m["model"] for m in response["models"]]
+            if model.provider_resource_id not in available_models:
+                raise ValueError(
+                    f"Model '{model.provider_resource_id}' is not available in Ollama. "
+                    f"Available models: {', '.join(available_models)}"
+                )
+            return model
         model = await self.register_helper.register_model(model)
         models = await self.client.ps()
         available_models = [m["model"] for m in models["models"]]
@@ -31,6 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -253,4 +254,13 @@ class TogetherInferenceAdapter(
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Together does not support media for embeddings"
+        r = self._get_client().embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+        )
+        embeddings = [item.embedding for item in r.data]
+        return EmbeddingsResponse(embeddings=embeddings)
@@ -29,6 +29,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -203,4 +204,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+
+        kwargs = {}
+        assert model.model_type == ModelType.embedding_model
+        assert model.metadata.get("embedding_dimensions")
+        kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "VLLM does not support media for embeddings"
+        response = self.client.embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
+        )
+
+        embeddings = [data.embedding for data in response.data]
+        return EmbeddingsResponse(embeddings=embeddings)
@@ -4,12 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
 from .config import ChromaRemoteImplConfig
+from llama_stack.providers.datatypes import Api, ProviderSpec
 
 
-async def get_adapter_impl(config: ChromaRemoteImplConfig, _deps):
+async def get_adapter_impl(config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec]):
     from .chroma import ChromaMemoryAdapter
 
-    impl = ChromaMemoryAdapter(config)
+    impl = ChromaMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
@ -13,8 +13,7 @@ import chromadb
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
|
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
||||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
|
||||||
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
|
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
BankWithIndex,
|
BankWithIndex,
|
||||||
|
@ -87,10 +86,14 @@ class ChromaIndex(EmbeddingIndex):
|
||||||
|
|
||||||
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig]
|
self,
|
||||||
|
config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig],
|
||||||
|
inference_api: Api.inference,
|
||||||
) -> None:
|
) -> None:
|
||||||
log.info(f"Initializing ChromaMemoryAdapter with url: {config}")
|
log.info(f"Initializing ChromaMemoryAdapter with url: {config}")
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.inference_api = inference_api
|
||||||
|
|
||||||
self.client = None
|
self.client = None
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
|
||||||
|
@ -127,10 +130,9 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
metadata={"bank": memory_bank.model_dump_json()},
|
metadata={"bank": memory_bank.model_dump_json()},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
bank_index = BankWithIndex(
|
self.cache[memory_bank.identifier] = BankWithIndex(
|
||||||
bank=memory_bank, index=ChromaIndex(self.client, collection)
|
memory_bank, ChromaIndex(self.client, collection), self.inference_api
|
||||||
)
|
)
|
||||||
self.cache[memory_bank.identifier] = bank_index
|
|
||||||
|
|
||||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||||
await self.cache[memory_bank_id].index.delete()
|
await self.cache[memory_bank_id].index.delete()
|
||||||
|
@ -166,6 +168,8 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
collection = await maybe_await(self.client.get_collection(bank_id))
|
collection = await maybe_await(self.client.get_collection(bank_id))
|
||||||
if not collection:
|
if not collection:
|
||||||
raise ValueError(f"Bank {bank_id} not found in Chroma")
|
raise ValueError(f"Bank {bank_id} not found in Chroma")
|
||||||
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
|
index = BankWithIndex(
|
||||||
|
bank, ChromaIndex(self.client, collection), self.inference_api
|
||||||
|
)
|
||||||
self.cache[bank_id] = index
|
self.cache[bank_id] = index
|
||||||
return index
|
return index
|
||||||
|
|
|
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import PGVectorConfig
 
 
-async def get_adapter_impl(config: PGVectorConfig, _deps):
+async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
     from .pgvector import PGVectorMemoryAdapter
 
-    impl = PGVectorMemoryAdapter(config)
+    impl = PGVectorMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -16,9 +16,9 @@ from pydantic import BaseModel, parse_obj_as
 
 from llama_stack.apis.memory import * # noqa: F403
 
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 
 from llama_stack.providers.utils.memory.vector_store import (
-    ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
     EmbeddingIndex,
 )
@@ -120,8 +120,9 @@ class PGVectorIndex(EmbeddingIndex):
 
 
 class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: PGVectorConfig) -> None:
+    def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.cursor = None
         self.conn = None
         self.cache = {}
@@ -160,27 +161,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     async def shutdown(self) -> None:
         pass
 
-    async def register_memory_bank(
-        self,
-        memory_bank: MemoryBank,
-    ) -> None:
+    async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
         assert (
             memory_bank.memory_bank_type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
 
-        upsert_models(
-            self.cursor,
-            [
-                (memory_bank.identifier, memory_bank),
-            ],
+        upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
+        index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, index, self.inference_api
         )
-
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[memory_bank.identifier] = index
 
     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
         await self.cache[memory_bank_id].index.delete()
         del self.cache[memory_bank_id]
@@ -203,14 +194,13 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = await self._get_and_cache_bank_index(bank_id)
         return await index.query_documents(query, params)
 
+        self.inference_api = inference_api
+
     async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
         if bank_id in self.cache:
             return self.cache[bank_id]
 
         bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        index = BankWithIndex(
-            bank=bank,
-            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[bank_id] = index
-        return index
+        index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
+        self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
+        return self.cache[bank_id]
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import QdrantConfig
 
 
-async def get_adapter_impl(config: QdrantConfig, _deps):
+async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]):
     from .qdrant import QdrantVectorMemoryAdapter
 
-    impl = QdrantVectorMemoryAdapter(config)
+    impl = QdrantVectorMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -101,10 +101,11 @@ class QdrantIndex(EmbeddingIndex):
 
 
 class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: QdrantConfig) -> None:
+    def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
         self.cache = {}
+        self.inference_api = inference_api
 
     async def initialize(self) -> None:
         pass
@@ -123,6 +124,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = BankWithIndex(
             bank=memory_bank,
             index=QdrantIndex(self.client, memory_bank.identifier),
+            inference_api=self.inference_api,
         )
 
         self.cache[memory_bank.identifier] = index
@@ -138,6 +140,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = BankWithIndex(
             bank=bank,
             index=QdrantIndex(client=self.client, collection_name=bank_id),
+            inference_api=self.inference_api,
         )
         self.cache[bank_id] = index
         return index
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401
 
 
-async def get_adapter_impl(config: WeaviateConfig, _deps):
+async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]):
     from .weaviate import WeaviateMemoryAdapter
 
-    impl = WeaviateMemoryAdapter(config)
+    impl = WeaviateMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -12,10 +12,11 @@ import weaviate
 import weaviate.classes as wvc
 from numpy.typing import NDArray
 from weaviate.classes.init import Auth
+from weaviate.classes.query import Filter
 
 from llama_stack.apis.memory import * # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
@@ -80,12 +81,21 @@ class WeaviateIndex(EmbeddingIndex):
 
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
+    async def delete(self, chunk_ids: List[str]) -> None:
+        collection = self.client.collections.get(self.collection_name)
+        collection.data.delete_many(
+            where=Filter.by_property("id").contains_any(chunk_ids)
+        )
+
 
 class WeaviateMemoryAdapter(
-    Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate
+    Memory,
+    NeedsRequestProviderData,
+    MemoryBanksProtocolPrivate,
 ):
-    def __init__(self, config: WeaviateConfig) -> None:
+    def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.client_cache = {}
         self.cache = {}
 
@@ -117,7 +127,7 @@ class WeaviateMemoryAdapter(
         memory_bank: MemoryBank,
     ) -> None:
         assert (
-            memory_bank.memory_bank_type == MemoryBankType.vector
+            memory_bank.memory_bank_type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
 
         client = self._get_client()
@@ -135,11 +145,11 @@ class WeaviateMemoryAdapter(
             ],
         )
 
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank,
+            WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+            self.inference_api,
         )
-        self.cache[memory_bank.identifier] = index
 
     async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
         if bank_id in self.cache:
@@ -156,6 +166,7 @@ class WeaviateMemoryAdapter(
         index = BankWithIndex(
             bank=bank,
             index=WeaviateIndex(client=client, collection_name=bank_id),
+            inference_api=self.inference_api,
         )
         self.cache[bank_id] = index
         return index
@@ -18,6 +18,12 @@ def pytest_addoption(parser):
         default=None,
         help="Specify the inference model to use for testing",
     )
+    parser.addoption(
+        "--embedding-model",
+        action="store",
+        default=None,
+        help="Specify the embedding model to use for testing",
+    )
 
 
 def pytest_configure(config):
@@ -9,9 +9,9 @@ import os
 import pytest
 import pytest_asyncio
 
-from llama_stack.apis.models import ModelInput
+from llama_stack.apis.models import ModelInput, ModelType
 
 from llama_stack.distribution.datatypes import Api, Provider
 
 from llama_stack.providers.inline.inference.meta_reference import (
     MetaReferenceInferenceConfig,
 )
@@ -47,6 +47,9 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
     inference_model = (
         [inference_model] if isinstance(inference_model, str) else inference_model
     )
+    # If embedding dimension is set, use the 8B model for testing
+    if os.getenv("EMBEDDING_DIMENSION"):
+        inference_model = ["meta-llama/Llama-3.1-8B-Instruct"]
 
     return ProviderFixture(
         providers=[
@@ -85,7 +88,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
     inference_model = (
         [inference_model] if isinstance(inference_model, str) else inference_model
     )
-    if "Llama3.1-8B-Instruct" in inference_model:
+    if inference_model and "Llama3.1-8B-Instruct" in inference_model:
         pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
 
     return ProviderFixture(
@@ -232,11 +235,23 @@ INFERENCE_FIXTURES = [
 async def inference_stack(request, inference_model):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
+    model_type = ModelType.llm
+    metadata = {}
+    if os.getenv("EMBEDDING_DIMENSION"):
+        model_type = ModelType.embedding_model
+        metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")
+
     test_stack = await construct_stack_for_test(
         [Api.inference],
         {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
-        models=[ModelInput(model_id=inference_model)],
+        models=[
+            ModelInput(
+                model_id=inference_model,
+                model_type=model_type,
+                metadata=metadata,
+            )
+        ],
     )
 
     return test_stack.impls[Api.inference], test_stack.impls[Api.models]
llama_stack/providers/tests/inference/test_embeddings.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from llama_stack.apis.inference import EmbeddingsResponse, ModelType
+
+# How to run this test:
+# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py
+
+
+class TestEmbeddings:
+    @pytest.mark.asyncio
+    async def test_embeddings(self, inference_model, inference_stack):
+        inference_impl, models_impl = inference_stack
+        model = await models_impl.get_model(inference_model)
+
+        if model.model_type != ModelType.embedding_model:
+            pytest.skip("This test is only applicable for embedding models")
+
+        response = await inference_impl.embeddings(
+            model_id=inference_model,
+            contents=["Hello, world!"],
+        )
+        assert isinstance(response, EmbeddingsResponse)
+        assert len(response.embeddings) > 0
+        assert all(isinstance(embedding, list) for embedding in response.embeddings)
+        assert all(
+            isinstance(value, float)
+            for embedding in response.embeddings
+            for value in embedding
+        )
+
+    @pytest.mark.asyncio
+    async def test_batch_embeddings(self, inference_model, inference_stack):
+        inference_impl, models_impl = inference_stack
+        model = await models_impl.get_model(inference_model)
+
+        if model.model_type != ModelType.embedding_model:
+            pytest.skip("This test is only applicable for embedding models")
+
+        texts = ["Hello, world!", "This is a test", "Testing embeddings"]
+
+        response = await inference_impl.embeddings(
+            model_id=inference_model,
+            contents=texts,
+        )
+
+        assert isinstance(response, EmbeddingsResponse)
+        assert len(response.embeddings) == len(texts)
+        assert all(isinstance(embedding, list) for embedding in response.embeddings)
+        assert all(
+            isinstance(value, float)
+            for embedding in response.embeddings
+            for value in embedding
+        )
+
+        embedding_dim = len(response.embeddings[0])
+        assert all(len(embedding) == embedding_dim for embedding in response.embeddings)
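Outside of pytest, the same embeddings call can be exercised directly against an inference implementation. A rough sketch, assuming a stack that already has an embedding model registered (`inference_impl` and the model id below are placeholders):

```python
import asyncio

from llama_stack.apis.inference import EmbeddingsResponse


async def embed_once(inference_impl, model_id: str) -> int:
    # Request embeddings for a single string and return the vector dimension.
    response = await inference_impl.embeddings(
        model_id=model_id,
        contents=["Hello, world!"],
    )
    assert isinstance(response, EmbeddingsResponse)
    return len(response.embeddings[0])


# Usage (with placeholder objects):
# dim = asyncio.run(embed_once(inference_impl, "my-embedding-model"))
```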
@@ -6,9 +6,65 @@

 import pytest

+from ..conftest import get_provider_fixture_overrides
+
+from ..inference.fixtures import INFERENCE_FIXTURES
 from .fixtures import MEMORY_FIXTURES


+DEFAULT_PROVIDER_COMBINATIONS = [
+    pytest.param(
+        {
+            "inference": "meta_reference",
+            "memory": "faiss",
+        },
+        id="meta_reference",
+        marks=pytest.mark.meta_reference,
+    ),
+    pytest.param(
+        {
+            "inference": "ollama",
+            "memory": "pgvector",
+        },
+        id="ollama",
+        marks=pytest.mark.ollama,
+    ),
+    pytest.param(
+        {
+            "inference": "together",
+            "memory": "chroma",
+        },
+        id="chroma",
+        marks=pytest.mark.chroma,
+    ),
+    pytest.param(
+        {
+            "inference": "bedrock",
+            "memory": "qdrant",
+        },
+        id="qdrant",
+        marks=pytest.mark.qdrant,
+    ),
+    pytest.param(
+        {
+            "inference": "fireworks",
+            "memory": "weaviate",
+        },
+        id="weaviate",
+        marks=pytest.mark.weaviate,
+    ),
+]
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--embedding-model",
+        action="store",
+        default=None,
+        help="Specify the embedding model to use for testing",
+    )
+
+
 def pytest_configure(config):
     for fixture_name in MEMORY_FIXTURES:
         config.addinivalue_line(

@@ -18,12 +74,22 @@ def pytest_configure(config):


 def pytest_generate_tests(metafunc):
-    if "memory_stack" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "memory_stack",
-            [
-                pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
-                for fixture_name in MEMORY_FIXTURES
-            ],
-            indirect=True,
-        )
+    if "embedding_model" in metafunc.fixturenames:
+        model = metafunc.config.getoption("--embedding-model")
+        if not model:
+            raise ValueError(
+                "No embedding model specified. Please provide a valid embedding model."
+            )
+        params = [pytest.param(model, id="")]
+
+        metafunc.parametrize("embedding_model", params, indirect=True)
+    if "memory_stack" in metafunc.fixturenames:
+        available_fixtures = {
+            "inference": INFERENCE_FIXTURES,
+            "memory": MEMORY_FIXTURES,
+        }
+        combinations = (
+            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            or DEFAULT_PROVIDER_COMBINATIONS
+        )
+        metafunc.parametrize("memory_stack", combinations, indirect=True)
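Each entry in `DEFAULT_PROVIDER_COMBINATIONS` (or an explicit fixture override) is a plain dict mapping API name to fixture name. A small sketch of what the parametrized `memory_stack` fixture below receives as `request.param` and how it resolves the per-API fixtures (the chosen pairing is just one of the defaults above):

```python
# One parametrized combination, as seen inside the memory_stack fixture:
fixture_dict = {"inference": "ollama", "memory": "pgvector"}

# The fixture looks up the "inference_ollama" and "memory_pgvector" pytest
# fixtures by name and merges their providers and provider_data.
fixture_names = [f"{key}_{fixture_dict[key]}" for key in ["inference", "memory"]]
# -> ["inference_ollama", "memory_pgvector"]
```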
@@ -10,6 +10,8 @@ import tempfile
 import pytest
 import pytest_asyncio

+from llama_stack.apis.inference import ModelInput, ModelType
+
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
 from llama_stack.providers.inline.memory.faiss import FaissImplConfig

@@ -105,14 +107,30 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]


 @pytest_asyncio.fixture(scope="session")
-async def memory_stack(request):
-    fixture_name = request.param
-    fixture = request.getfixturevalue(f"memory_{fixture_name}")
+async def memory_stack(embedding_model, request):
+    fixture_dict = request.param
+
+    providers = {}
+    provider_data = {}
+    for key in ["inference", "memory"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)

     test_stack = await construct_stack_for_test(
-        [Api.memory],
-        {"memory": fixture.providers},
-        fixture.provider_data,
+        [Api.memory, Api.inference],
+        providers,
+        provider_data,
+        models=[
+            ModelInput(
+                model_id=embedding_model,
+                model_type=ModelType.embedding_model,
+                metadata={
+                    "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
+                },
+            )
+        ],
     )

     return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]
@@ -45,12 +45,14 @@ def sample_documents():
     ]


-async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
+async def register_memory_bank(
+    banks_impl: MemoryBanks, embedding_model: str
+) -> MemoryBank:
     bank_id = f"test_bank_{uuid.uuid4().hex}"
     return await banks_impl.register_memory_bank(
         memory_bank_id=bank_id,
         params=VectorMemoryBankParams(
-            embedding_model="all-MiniLM-L6-v2",
+            embedding_model=embedding_model,
             chunk_size_in_tokens=512,
             overlap_size_in_tokens=64,
         ),

@@ -59,11 +61,11 @@ async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:

 class TestMemory:
     @pytest.mark.asyncio
-    async def test_banks_list(self, memory_stack):
+    async def test_banks_list(self, memory_stack, embedding_model):
         _, banks_impl = memory_stack

         # Register a test bank
-        registered_bank = await register_memory_bank(banks_impl)
+        registered_bank = await register_memory_bank(banks_impl, embedding_model)

         try:
             # Verify our bank shows up in list

@@ -84,7 +86,7 @@ class TestMemory:
         )

     @pytest.mark.asyncio
-    async def test_banks_register(self, memory_stack):
+    async def test_banks_register(self, memory_stack, embedding_model):
         _, banks_impl = memory_stack

         bank_id = f"test_bank_{uuid.uuid4().hex}"

@@ -94,7 +96,7 @@ class TestMemory:
         await banks_impl.register_memory_bank(
             memory_bank_id=bank_id,
             params=VectorMemoryBankParams(
-                embedding_model="all-MiniLM-L6-v2",
+                embedding_model=embedding_model,
                 chunk_size_in_tokens=512,
                 overlap_size_in_tokens=64,
             ),

@@ -109,7 +111,7 @@ class TestMemory:
         await banks_impl.register_memory_bank(
             memory_bank_id=bank_id,
             params=VectorMemoryBankParams(
-                embedding_model="all-MiniLM-L6-v2",
+                embedding_model=embedding_model,
                 chunk_size_in_tokens=512,
                 overlap_size_in_tokens=64,
             ),

@@ -126,13 +128,15 @@ class TestMemory:
         await banks_impl.unregister_memory_bank(bank_id)

     @pytest.mark.asyncio
-    async def test_query_documents(self, memory_stack, sample_documents):
+    async def test_query_documents(
+        self, memory_stack, embedding_model, sample_documents
+    ):
         memory_impl, banks_impl = memory_stack

         with pytest.raises(ValueError):
             await memory_impl.insert_documents("test_bank", sample_documents)

-        registered_bank = await register_memory_bank(banks_impl)
+        registered_bank = await register_memory_bank(banks_impl, embedding_model)
         await memory_impl.insert_documents(
             registered_bank.memory_bank_id, sample_documents
         )

@@ -165,13 +169,13 @@ class TestMemory:

         # Test case 5: Query with threshold on similarity score
         query5 = "quantum computing"  # Not directly related to any document
-        params5 = {"score_threshold": 0.2}
+        params5 = {"score_threshold": 0.01}
         response5 = await memory_impl.query_documents(
             registered_bank.memory_bank_id, query5, params5
         )
         assert_valid_response(response5)
         print("The scores are:", response5.scores)
-        assert all(score >= 0.2 for score in response5.scores)
+        assert all(score >= 0.01 for score in response5.scores)


 def assert_valid_response(response: QueryDocumentsResponse):
llama_stack/providers/utils/inference/embedding_mixin.py (new file, 47 lines)
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import logging
+from typing import List
+
+from llama_models.llama3.api.datatypes import InterleavedTextMedia
+
+from llama_stack.apis.inference.inference import EmbeddingsResponse, ModelStore
+
+EMBEDDING_MODELS = {}
+
+
+log = logging.getLogger(__name__)
+
+
+class SentenceTransformerEmbeddingMixin:
+    model_store: ModelStore
+
+    async def embeddings(
+        self,
+        model_id: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse:
+        model = await self.model_store.get_model(model_id)
+        embedding_model = self._load_sentence_transformer_model(
+            model.provider_resource_id
+        )
+        embeddings = embedding_model.encode(contents)
+        return EmbeddingsResponse(embeddings=embeddings)
+
+    def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
+        global EMBEDDING_MODELS
+
+        loaded_model = EMBEDDING_MODELS.get(model)
+        if loaded_model is not None:
+            return loaded_model
+
+        log.info(f"Loading sentence transformer for {model}...")
+        from sentence_transformers import SentenceTransformer
+
+        loaded_model = SentenceTransformer(model)
+        EMBEDDING_MODELS[model] = loaded_model
+        return loaded_model
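A hypothetical sketch of how a provider can adopt this mixin: `embeddings()` is inherited as-is, so the provider only needs a `model_store` that can resolve a `model_id` to a `provider_resource_id`. The class and wiring below are illustrative, not part of the diff:

```python
from llama_stack.apis.inference.inference import ModelStore
from llama_stack.providers.utils.inference.embedding_mixin import (
    SentenceTransformerEmbeddingMixin,
)


class ExampleEmbeddingProvider(SentenceTransformerEmbeddingMixin):
    """Hypothetical provider that serves embeddings via sentence-transformers."""

    def __init__(self, model_store: ModelStore) -> None:
        # The mixin's embeddings() calls self.model_store.get_model(model_id)
        # and loads a SentenceTransformer for model.provider_resource_id.
        self.model_store = model_store
```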
@@ -22,28 +22,10 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.providers.datatypes import Api

 log = logging.getLogger(__name__)

-ALL_MINILM_L6_V2_DIMENSION = 384
-
-EMBEDDING_MODELS = {}
-
-
-def get_embedding_model(model: str) -> "SentenceTransformer":
-    global EMBEDDING_MODELS
-
-    loaded_model = EMBEDDING_MODELS.get(model)
-    if loaded_model is not None:
-        return loaded_model
-
-    log.info(f"Loading sentence transformer for {model}...")
-    from sentence_transformers import SentenceTransformer
-
-    loaded_model = SentenceTransformer(model)
-    EMBEDDING_MODELS[model] = loaded_model
-    return loaded_model
-
-
 def parse_pdf(data: bytes) -> str:
     # For PDF and DOC/DOCX files, we can't reliably convert to string

@@ -166,12 +148,12 @@ class EmbeddingIndex(ABC):
 class BankWithIndex:
     bank: VectorMemoryBank
     index: EmbeddingIndex
+    inference_api: Api.inference

     async def insert_documents(
         self,
         documents: List[MemoryBankDocument],
     ) -> None:
-        model = get_embedding_model(self.bank.embedding_model)
         for doc in documents:
             content = await content_from_doc(doc)
             chunks = make_overlapped_chunks(

@@ -183,7 +165,10 @@ class BankWithIndex:
             )
             if not chunks:
                 continue
-            embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
+            embeddings_response = await self.inference_api.embeddings(
+                self.bank.embedding_model, [x.content for x in chunks]
+            )
+            embeddings = np.array(embeddings_response.embeddings)

             await self.index.add_chunks(chunks, embeddings)

@@ -208,6 +193,8 @@ class BankWithIndex:
         else:
             query_str = _process(query)

-        model = get_embedding_model(self.bank.embedding_model)
-        query_vector = model.encode([query_str])[0].astype(np.float32)
+        embeddings_response = await self.inference_api.embeddings(
+            self.bank.embedding_model, [query_str]
+        )
+        query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
         return await self.index.query(query_vector, k, score_threshold)
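Since `BankWithIndex` now holds an `inference_api` handle instead of loading sentence-transformers locally, memory providers are expected to pass their inference API through at construction time. A minimal sketch, assuming placeholder `bank`, `index`, and `inference_api` objects:

```python
# Hypothetical wiring inside a memory provider: embeddings for inserts (and,
# symmetrically, for queries) are delegated to the inference provider.
bank_with_index = BankWithIndex(
    bank=bank,                    # VectorMemoryBank; bank.embedding_model selects the model
    index=index,                  # concrete EmbeddingIndex implementation (e.g. FAISS-backed)
    inference_api=inference_api,  # inference API used to produce embeddings
)

# Document ingestion now calls inference_api.embeddings() per chunk batch:
await bank_with_index.insert_documents(documents)
```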