Make embedding generation go through inference (#606)
This PR does the following: 1) adds the ability to generate embeddings in all supported inference providers, and 2) moves all the memory providers to use the inference API, with improved memory tests that set up the inference stack correctly and use the embedding models. This is a merge of #589 and #598.
Parent: a14785af46
Commit: 96e158eaac
37 changed files with 677 additions and 156 deletions
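Before the per-file hunks, here is a minimal usage sketch of the flow this PR enables, assuming `models` and `inference` are handles to the Models and Inference APIs of a running stack; the model identifier and provider id are illustrative, not taken from this change.

from llama_stack.apis.models import ModelType


async def embed_documents(models, inference, texts):
    # Embedding models are now ordinary Model resources; the routing table requires
    # an "embedding_dimension" entry in their metadata (see the register_model hunks below).
    await models.register_model(
        model_id="all-MiniLM-L6-v2",          # illustrative model id
        provider_id="sentence-transformers",  # illustrative provider id
        model_type=ModelType.embedding_model,
        metadata={"embedding_dimension": 384},
    )

    # The InferenceRouter checks the model type and dispatches to the provider's
    # embeddings() implementation, which returns an EmbeddingsResponse.
    response = await inference.embeddings(
        model_id="all-MiniLM-L6-v2",
        contents=texts,
    )
    return response.embeddings  # one vector per input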
@@ -89,6 +89,7 @@ class VectorMemoryBank(MemoryBankResourceMixin):
     memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
     embedding_model: str
     chunk_size_in_tokens: int
+    embedding_dimension: Optional[int] = 384  # default to minilm-l6-v2
     overlap_size_in_tokens: Optional[int] = None
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable

 from llama_models.schema_utils import json_schema_type, webmethod

@@ -20,6 +21,11 @@ class CommonModelFields(BaseModel):
     )


+class ModelType(Enum):
+    llm = "llm"
+    embedding_model = "embedding"
+
+
 @json_schema_type
 class Model(CommonModelFields, Resource):
     type: Literal[ResourceType.model.value] = ResourceType.model.value

@@ -34,12 +40,14 @@ class Model(CommonModelFields, Resource):

     model_config = ConfigDict(protected_namespaces=())

+    model_type: ModelType = Field(default=ModelType.llm)
+

 class ModelInput(CommonModelFields):
     model_id: str
     provider_id: Optional[str] = None
     provider_model_id: Optional[str] = None
+    model_type: Optional[ModelType] = ModelType.llm
     model_config = ConfigDict(protected_namespaces=())


@@ -59,6 +67,7 @@ class Models(Protocol):
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        model_type: Optional[ModelType] = None,
     ) -> Model: ...

     @webmethod(route="/models/unregister", method="POST")
@@ -88,9 +88,10 @@ class InferenceRouter(Inference):
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        model_type: Optional[ModelType] = None,
     ) -> None:
         await self.routing_table.register_model(
-            model_id, provider_model_id, provider_id, metadata
+            model_id, provider_model_id, provider_id, metadata, model_type
         )

     async def chat_completion(

@@ -105,6 +106,13 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.routing_table.get_model(model_id)
+        if model is None:
+            raise ValueError(f"Model '{model_id}' not found")
+        if model.model_type == ModelType.embedding_model:
+            raise ValueError(
+                f"Model '{model_id}' is an embedding model and does not support chat completions"
+            )
         params = dict(
             model_id=model_id,
             messages=messages,

@@ -131,6 +139,13 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.routing_table.get_model(model_id)
+        if model is None:
+            raise ValueError(f"Model '{model_id}' not found")
+        if model.model_type == ModelType.embedding_model:
+            raise ValueError(
+                f"Model '{model_id}' is an embedding model and does not support chat completions"
+            )
         provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,

@@ -150,6 +165,13 @@ class InferenceRouter(Inference):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
+        model = await self.routing_table.get_model(model_id)
+        if model is None:
+            raise ValueError(f"Model '{model_id}' not found")
+        if model.model_type == ModelType.llm:
+            raise ValueError(
+                f"Model '{model_id}' is an LLM model and does not support embeddings"
+            )
         return await self.routing_table.get_provider_impl(model_id).embeddings(
             model_id=model_id,
             contents=contents,
@@ -209,6 +209,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         provider_model_id: Optional[str] = None,
         provider_id: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        model_type: Optional[ModelType] = None,
     ) -> Model:
         if provider_model_id is None:
             provider_model_id = model_id

@@ -222,11 +223,21 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             )
         if metadata is None:
             metadata = {}
+        if model_type is None:
+            model_type = ModelType.llm
+        if (
+            "embedding_dimension" not in metadata
+            and model_type == ModelType.embedding_model
+        ):
+            raise ValueError(
+                "Embedding model must have an embedding dimension in its metadata"
+            )
         model = Model(
             identifier=model_id,
             provider_resource_id=provider_model_id,
             provider_id=provider_id,
             metadata=metadata,
+            model_type=model_type,
         )
         registered_model = await self.register_object(model)
         return registered_model

@@ -298,16 +309,29 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
             raise ValueError(
                 "No provider specified and multiple providers available. Please specify a provider_id."
             )
-        memory_bank = parse_obj_as(
-            MemoryBank,
-            {
+        model = await self.get_object_by_identifier("model", params.embedding_model)
+        if model is None:
+            raise ValueError(f"Model {params.embedding_model} not found")
+        if model.model_type != ModelType.embedding_model:
+            raise ValueError(
+                f"Model {params.embedding_model} is not an embedding model"
+            )
+        if "embedding_dimension" not in model.metadata:
+            raise ValueError(
+                f"Model {params.embedding_model} does not have an embedding dimension"
+            )
+        memory_bank_data = {
             "identifier": memory_bank_id,
             "type": ResourceType.memory_bank.value,
             "provider_id": provider_id,
             "provider_resource_id": provider_memory_bank_id,
             **params.model_dump(),
-            },
-        )
+        }
+        if params.memory_bank_type == MemoryBankType.vector.value:
+            memory_bank_data["embedding_dimension"] = model.metadata[
+                "embedding_dimension"
+            ]
+        memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
         await self.register_object(memory_bank)
         return memory_bank
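For readers skimming the routing-table hunks above, the validation they introduce can be condensed into the following hedged sketch; `model` stands for the object returned by get_object_by_identifier("model", ...), and the helper name is ours, not part of the change.

def resolve_embedding_dimension(model, embedding_model_id: str) -> int:
    # Mirrors the checks added to MemoryBanksRoutingTable.register_memory_bank:
    # the referenced model must exist, be an embedding model, and declare its dimension.
    if model is None:
        raise ValueError(f"Model {embedding_model_id} not found")
    if model.model_type.value != "embedding":
        raise ValueError(f"Model {embedding_model_id} is not an embedding model")
    if "embedding_dimension" not in model.metadata:
        raise ValueError(f"Model {embedding_model_id} does not have an embedding dimension")
    # The resolved dimension is copied onto the VectorMemoryBank so providers such as
    # faiss or pgvector can size their indices without a hard-coded constant.
    return model.metadata["embedding_dimension"]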
@@ -40,7 +40,7 @@ class DistributionRegistry(Protocol):


 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v2"
+KEY_VERSION = "v3"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
@@ -200,10 +200,13 @@ API responses, specify the adapter here.
         return self.adapter.provider_data_validator


-def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
+def remote_provider_spec(
+    api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None
+) -> RemoteProviderSpec:
     return RemoteProviderSpec(
         api=api,
         provider_type=f"remote::{adapter.adapter_type}",
         config_class=adapter.config_class,
         adapter=adapter,
+        api_dependencies=api_dependencies or [],
     )
@@ -16,12 +16,14 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.providers.utils.inference.model_registry import build_model_alias
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.utils.inference.embedding_mixin import (
+    SentenceTransformerEmbeddingMixin,
+)
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_media_to_url,
     request_has_media,
 )

 from .config import MetaReferenceInferenceConfig
 from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator

@@ -32,12 +34,17 @@ log = logging.getLogger(__name__)
 SEMAPHORE = asyncio.Semaphore(1)


-class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
+class MetaReferenceInferenceImpl(
+    SentenceTransformerEmbeddingMixin,
+    Inference,
+    ModelsProtocolPrivate,
+):
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
-        ModelRegistryHelper.__init__(
-            self,
+        if model is None:
+            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
+        self.model_registry_helper = ModelRegistryHelper(
             [
                 build_model_alias(
                     model.descriptor(),

@@ -45,8 +52,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
                 )
             ],
         )
-        if model is None:
-            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
         self.model = model
         # verify that the checkpoint actually is for this model lol

@@ -76,6 +81,12 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
     async def unregister_model(self, model_id: str) -> None:
         pass

+    async def register_model(self, model: Model) -> Model:
+        model = await self.model_registry_helper.register_model(model)
+        if model.model_type == ModelType.embedding_model:
+            self._load_sentence_transformer_model(model.provider_resource_id)
+        return model
+
     async def completion(
         self,
         model_id: str,

@@ -394,13 +405,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
                 for x in impl():
                     yield x

-    async def embeddings(
-        self,
-        model_id: str,
-        contents: List[InterleavedTextMedia],
-    ) -> EmbeddingsResponse:
-        raise NotImplementedError()
-

 async def request_with_localized_media(
     request: Union[ChatCompletionRequest, CompletionRequest],
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.providers.inline.inference.sentence_transformers.config import (
+    SentenceTransformersInferenceConfig,
+)
+
+
+async def get_provider_impl(
+    config: SentenceTransformersInferenceConfig,
+    _deps,
+):
+    from .sentence_transformers import SentenceTransformersInferenceImpl
+
+    impl = SentenceTransformersInferenceImpl(config)
+    await impl.initialize()
+    return impl
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+
+class SentenceTransformersInferenceConfig(BaseModel): ...
@@ -0,0 +1,74 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import logging
+from typing import AsyncGenerator, List, Optional, Union
+
+from llama_stack.apis.inference import (
+    CompletionResponse,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    SamplingParams,
+    ToolChoice,
+    ToolDefinition,
+    ToolPromptFormat,
+)
+from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
+from llama_stack.providers.utils.inference.embedding_mixin import (
+    SentenceTransformerEmbeddingMixin,
+)
+from .config import SentenceTransformersInferenceConfig
+
+log = logging.getLogger(__name__)
+
+
+class SentenceTransformersInferenceImpl(
+    SentenceTransformerEmbeddingMixin,
+    Inference,
+    ModelsProtocolPrivate,
+):
+    def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
+        self.config = config
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def register_model(self, model: Model) -> None:
+        _ = self._load_sentence_transformer_model(model.provider_resource_id)
+        return model
+
+    async def unregister_model(self, model_id: str) -> None:
+        pass
+
+    async def completion(
+        self,
+        model_id: str,
+        content: str,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> Union[CompletionResponse, AsyncGenerator]:
+        raise ValueError("Sentence transformers don't support completion")
+
+    async def chat_completion(
+        self,
+        model_id: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        raise ValueError("Sentence transformers don't support chat completion")
@@ -4,16 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import FaissImplConfig


-async def get_provider_impl(config: FaissImplConfig, _deps):
+async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
     from .faiss import FaissMemoryImpl

     assert isinstance(
         config, FaissImplConfig
     ), f"Unexpected config type: {type(config)}"

-    impl = FaissMemoryImpl(config)
+    impl = FaissMemoryImpl(config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -19,11 +19,10 @@ from numpy.typing import NDArray
 from llama_models.llama3.api.datatypes import *  # noqa: F403

 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl

 from llama_stack.providers.utils.memory.vector_store import (
-    ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
     EmbeddingIndex,
 )

@@ -32,7 +31,8 @@ from .config import FaissImplConfig

 logger = logging.getLogger(__name__)

-MEMORY_BANKS_PREFIX = "memory_banks:v1::"
+MEMORY_BANKS_PREFIX = "memory_banks:v2::"
+FAISS_INDEX_PREFIX = "faiss_index:v2::"


 class FaissIndex(EmbeddingIndex):

@@ -56,7 +56,7 @@ class FaissIndex(EmbeddingIndex):
         if not self.kvstore:
             return

-        index_key = f"faiss_index:v1::{self.bank_id}"
+        index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
         stored_data = await self.kvstore.get(index_key)

         if stored_data:

@@ -85,16 +85,25 @@ class FaissIndex(EmbeddingIndex):
             "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
         }

-        index_key = f"faiss_index:v1::{self.bank_id}"
+        index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
         await self.kvstore.set(key=index_key, value=json.dumps(data))

     async def delete(self):
         if not self.kvstore or not self.bank_id:
             return

-        await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
+        await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}")

     async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+        # Add dimension check
+        embedding_dim = (
+            embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
+        )
+        if embedding_dim != self.index.d:
+            raise ValueError(
+                f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
+            )
+
         indexlen = len(self.id_by_index)
         for i, chunk in enumerate(chunks):
             self.chunk_by_index[indexlen + i] = chunk

@@ -124,8 +133,9 @@ class FaissIndex(EmbeddingIndex):


 class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: FaissImplConfig) -> None:
+    def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.cache = {}
         self.kvstore = None

@@ -139,10 +149,11 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
         for bank_data in stored_banks:
             bank = VectorMemoryBank.model_validate_json(bank_data)
             index = BankWithIndex(
-                bank=bank,
-                index=await FaissIndex.create(
-                    ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier
+                bank,
+                await FaissIndex.create(
+                    bank.embedding_dimension, self.kvstore, bank.identifier
                 ),
+                self.inference_api,
             )
             self.cache[bank.identifier] = index

@@ -166,13 +177,13 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
         )

         # Store in cache
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=await FaissIndex.create(
-                ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank,
+            await FaissIndex.create(
+                memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
             ),
+            self.inference_api,
         )
-        self.cache[memory_bank.identifier] = index

     async def list_memory_banks(self) -> List[MemoryBank]:
         return [i.bank for i in self.cache.values()]
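The faiss hunks above show the shape of the new wiring: the provider receives the inference API via deps[Api.inference] and hands it to BankWithIndex along with the bank's embedding_dimension. Below is a hedged sketch of how such a wrapper can delegate embedding generation to the inference API; the real helper lives in llama_stack/providers/utils/memory/vector_store.py, which is not part of this diff, so names and details may differ.

import numpy as np


class BankWithInferenceEmbeddings:
    # Illustrative stand-in for BankWithIndex; the class name and add_texts signature are assumptions.
    def __init__(self, bank, index, inference_api):
        self.bank = bank                    # VectorMemoryBank, carries embedding_model
        self.index = index                  # e.g. a FaissIndex sized to bank.embedding_dimension
        self.inference_api = inference_api  # injected via deps[Api.inference]

    async def add_texts(self, chunks, texts):
        # Embeddings come from the registered embedding model through the inference
        # API, not from a model loaded inside the memory provider.
        response = await self.inference_api.embeddings(
            model_id=self.bank.embedding_model,
            contents=texts,
        )
        embeddings = np.array(response.embeddings, dtype=np.float32)
        await self.index.add_chunks(chunks, embeddings)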
@@ -18,6 +18,7 @@ META_REFERENCE_DEPS = [
     "transformers",
     "zmq",
     "lm-format-enforcer",
+    "sentence-transformers",
 ]


@@ -52,6 +53,13 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.inline.inference.vllm",
             config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
         ),
+        InlineProviderSpec(
+            api=Api.inference,
+            provider_type="inline::sentence-transformers",
+            pip_packages=["sentence-transformers"],
+            module="llama_stack.providers.inline.inference.sentence_transformers",
+            config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
+        ),
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
@@ -39,6 +39,7 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.inline.memory.faiss",
             config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
             deprecation_warning="Please use the `inline::faiss` provider instead.",
+            api_dependencies=[Api.inference],
         ),
         InlineProviderSpec(
             api=Api.memory,

@@ -46,6 +47,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
             module="llama_stack.providers.inline.memory.faiss",
             config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             Api.memory,

@@ -55,6 +57,7 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.chroma",
                 config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig",
             ),
+            api_dependencies=[Api.inference],
         ),
         InlineProviderSpec(
             api=Api.memory,

@@ -71,6 +74,7 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.pgvector",
                 config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
             ),
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             Api.memory,

@@ -81,6 +85,7 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
                 provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
             ),
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             api=Api.memory,

@@ -90,6 +95,7 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.sample",
                 config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
             ),
+            api_dependencies=[],
         ),
         remote_provider_spec(
             Api.memory,

@@ -99,5 +105,6 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.qdrant",
                 config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
             ),
+            api_dependencies=[Api.inference],
         ),
     ]
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 from typing import *  # noqa: F403
+import json

 from botocore.client import BaseClient
 from llama_models.datatypes import CoreModelId

@@ -19,8 +20,10 @@ from llama_stack.providers.utils.inference.model_registry import (

 from llama_stack.apis.inference import *  # noqa: F403

 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
+from llama_stack.providers.utils.inference.prompt_adapter import content_has_media


 model_aliases = [

@@ -448,4 +451,21 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+        embeddings = []
+        for content in contents:
+            assert not content_has_media(
+                content
+            ), "Bedrock does not support media for embeddings"
+            input_text = interleaved_text_media_as_str(content)
+            input_body = {"inputText": input_text}
+            body = json.dumps(input_body)
+            response = self.client.invoke_model(
+                body=body,
+                modelId=model.provider_resource_id,
+                accept="application/json",
+                contentType="application/json",
+            )
+            response_body = json.loads(response.get("body").read())
+            embeddings.append(response_body.get("embedding"))
+        return EmbeddingsResponse(embeddings=embeddings)
@@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
 @json_schema_type
 class FireworksImplConfig(BaseModel):
     url: str = Field(
-        default="https://api.fireworks.ai/inference",
+        default="https://api.fireworks.ai/inference/v1",
         description="The URL for the Fireworks server",
     )
     api_key: Optional[str] = Field(

@@ -24,6 +24,6 @@ class FireworksImplConfig(BaseModel):
     @classmethod
     def sample_run_config(cls) -> Dict[str, Any]:
         return {
-            "url": "https://api.fireworks.ai/inference",
+            "url": "https://api.fireworks.ai/inference/v1",
             "api_key": "${env.FIREWORKS_API_KEY}",
         }
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional, Union

 from fireworks.client import Fireworks
 from llama_models.datatypes import CoreModelId

@@ -28,6 +28,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )

@@ -89,17 +90,19 @@ class FireworksInferenceAdapter(
     async def shutdown(self) -> None:
         pass

-    def _get_client(self) -> Fireworks:
-        fireworks_api_key = None
+    def _get_api_key(self) -> str:
         if self.config.api_key is not None:
-            fireworks_api_key = self.config.api_key
+            return self.config.api_key
         else:
             provider_data = self.get_request_provider_data()
             if provider_data is None or not provider_data.fireworks_api_key:
                 raise ValueError(
                     'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
                 )
-            fireworks_api_key = provider_data.fireworks_api_key
+            return provider_data.fireworks_api_key
+
+    def _get_client(self) -> Fireworks:
+        fireworks_api_key = self._get_api_key()
         return Fireworks(api_key=fireworks_api_key)

     async def completion(

@@ -264,4 +267,19 @@ class FireworksInferenceAdapter(
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+
+        kwargs = {}
+        if model.metadata.get("embedding_dimensions"):
+            kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Fireworks does not support media for embeddings"
+        response = self._get_client().embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
+        )
+
+        embeddings = [data.embedding for data in response.data]
+        return EmbeddingsResponse(embeddings=embeddings)
@@ -36,6 +36,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_image_media_to_url,
     request_has_media,
 )

@@ -321,9 +322,30 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Ollama does not support media for embeddings"
+        response = await self.client.embed(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+        )
+        embeddings = response["embeddings"]
+
+        return EmbeddingsResponse(embeddings=embeddings)

     async def register_model(self, model: Model) -> Model:
+        # ollama does not have embedding models running. Check if the model is in list of available models.
+        if model.model_type == ModelType.embedding_model:
+            response = await self.client.list()
+            available_models = [m["model"] for m in response["models"]]
+            if model.provider_resource_id not in available_models:
+                raise ValueError(
+                    f"Model '{model.provider_resource_id}' is not available in Ollama. "
+                    f"Available models: {', '.join(available_models)}"
+                )
+            return model
         model = await self.register_helper.register_model(model)
         models = await self.client.ps()
         available_models = [m["model"] for m in models["models"]]
@@ -31,6 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )

@@ -253,4 +254,13 @@ class TogetherInferenceAdapter(
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Together does not support media for embeddings"
+        r = self._get_client().embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+        )
+        embeddings = [item.embedding for item in r.data]
+        return EmbeddingsResponse(embeddings=embeddings)
@@ -29,6 +29,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )

@@ -203,4 +204,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+
+        kwargs = {}
+        assert model.model_type == ModelType.embedding_model
+        assert model.metadata.get("embedding_dimensions")
+        kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "VLLM does not support media for embeddings"
+        response = self.client.embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
+        )
+
+        embeddings = [data.embedding for data in response.data]
+        return EmbeddingsResponse(embeddings=embeddings)
@@ -4,12 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import ChromaRemoteImplConfig


-async def get_adapter_impl(config: ChromaRemoteImplConfig, _deps):
+async def get_adapter_impl(
+    config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec]
+):
     from .chroma import ChromaMemoryAdapter

-    impl = ChromaMemoryAdapter(config)
+    impl = ChromaMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -13,8 +13,7 @@ import chromadb
 from numpy.typing import NDArray

 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
-
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,

@@ -87,10 +86,14 @@ class ChromaIndex(EmbeddingIndex):


 class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(
-        self, config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig]
+        self,
+        config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig],
+        inference_api: Api.inference,
     ) -> None:
         log.info(f"Initializing ChromaMemoryAdapter with url: {config}")
         self.config = config
+        self.inference_api = inference_api
+
         self.client = None
         self.cache = {}

@@ -127,10 +130,9 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
                 metadata={"bank": memory_bank.model_dump_json()},
             )
         )
-        bank_index = BankWithIndex(
-            bank=memory_bank, index=ChromaIndex(self.client, collection)
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, ChromaIndex(self.client, collection), self.inference_api
         )
-        self.cache[memory_bank.identifier] = bank_index

     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
         await self.cache[memory_bank_id].index.delete()

@@ -166,6 +168,8 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         collection = await maybe_await(self.client.get_collection(bank_id))
         if not collection:
             raise ValueError(f"Bank {bank_id} not found in Chroma")
-        index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
+        index = BankWithIndex(
+            bank, ChromaIndex(self.client, collection), self.inference_api
+        )
         self.cache[bank_id] = index
         return index
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import PGVectorConfig


-async def get_adapter_impl(config: PGVectorConfig, _deps):
+async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
     from .pgvector import PGVectorMemoryAdapter

-    impl = PGVectorMemoryAdapter(config)
+    impl = PGVectorMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -16,9 +16,9 @@ from pydantic import BaseModel, parse_obj_as

 from llama_stack.apis.memory import *  # noqa: F403

-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate

 from llama_stack.providers.utils.memory.vector_store import (
-    ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
     EmbeddingIndex,
 )

@@ -120,8 +120,9 @@ class PGVectorIndex(EmbeddingIndex):


 class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: PGVectorConfig) -> None:
+    def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.cursor = None
         self.conn = None
         self.cache = {}

@@ -160,27 +161,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    async def register_memory_bank(
-        self,
-        memory_bank: MemoryBank,
-    ) -> None:
+    async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
         assert (
             memory_bank.memory_bank_type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"

-        upsert_models(
-            self.cursor,
-            [
-                (memory_bank.identifier, memory_bank),
-            ],
-        )
-
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[memory_bank.identifier] = index
+        upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
+        index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, index, self.inference_api
+        )

     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
         await self.cache[memory_bank_id].index.delete()
         del self.cache[memory_bank_id]

@@ -203,14 +194,13 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = await self._get_and_cache_bank_index(bank_id)
         return await index.query_documents(query, params)

+        self.inference_api = inference_api
+
     async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
         if bank_id in self.cache:
             return self.cache[bank_id]

         bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        index = BankWithIndex(
-            bank=bank,
-            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[bank_id] = index
-        return index
+        index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
+        self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
+        return self.cache[bank_id]
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import QdrantConfig


-async def get_adapter_impl(config: QdrantConfig, _deps):
+async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]):
     from .qdrant import QdrantVectorMemoryAdapter

-    impl = QdrantVectorMemoryAdapter(config)
+    impl = QdrantVectorMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -101,10 +101,11 @@ class QdrantIndex(EmbeddingIndex):


 class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: QdrantConfig) -> None:
+    def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
         self.cache = {}
+        self.inference_api = inference_api

     async def initialize(self) -> None:
         pass

@@ -123,6 +124,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = BankWithIndex(
             bank=memory_bank,
             index=QdrantIndex(self.client, memory_bank.identifier),
+            inference_api=self.inference_api,
         )

         self.cache[memory_bank.identifier] = index

@@ -138,6 +140,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = BankWithIndex(
             bank=bank,
             index=QdrantIndex(client=self.client, collection_name=bank_id),
+            inference_api=self.inference_api,
         )
         self.cache[bank_id] = index
         return index
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import WeaviateConfig, WeaviateRequestProviderData  # noqa: F401


-async def get_adapter_impl(config: WeaviateConfig, _deps):
+async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]):
     from .weaviate import WeaviateMemoryAdapter

-    impl = WeaviateMemoryAdapter(config)
+    impl = WeaviateMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
@ -12,10 +12,11 @@ import weaviate
|
||||||
import weaviate.classes as wvc
|
import weaviate.classes as wvc
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
from weaviate.classes.init import Auth
|
from weaviate.classes.init import Auth
|
||||||
|
from weaviate.classes.query import Filter
|
||||||
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
BankWithIndex,
|
BankWithIndex,
|
||||||
EmbeddingIndex,
|
EmbeddingIndex,
|
||||||
|
@ -80,12 +81,21 @@ class WeaviateIndex(EmbeddingIndex):
|
||||||
|
|
||||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
async def delete(self, chunk_ids: List[str]) -> None:
|
||||||
|
collection = self.client.collections.get(self.collection_name)
|
||||||
|
collection.data.delete_many(
|
||||||
|
where=Filter.by_property("id").contains_any(chunk_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WeaviateMemoryAdapter(
|
class WeaviateMemoryAdapter(
|
||||||
Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate
|
Memory,
|
||||||
|
NeedsRequestProviderData,
|
||||||
|
MemoryBanksProtocolPrivate,
|
||||||
):
|
):
|
||||||
def __init__(self, config: WeaviateConfig) -> None:
|
def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.inference_api = inference_api
|
||||||
self.client_cache = {}
|
self.client_cache = {}
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
|
||||||
|
@ -117,7 +127,7 @@ class WeaviateMemoryAdapter(
|
||||||
memory_bank: MemoryBank,
|
memory_bank: MemoryBank,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert (
|
assert (
|
||||||
memory_bank.memory_bank_type == MemoryBankType.vector
|
memory_bank.memory_bank_type == MemoryBankType.vector.value
|
||||||
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
|
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
|
||||||
|
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
|
@ -135,11 +145,11 @@ class WeaviateMemoryAdapter(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
index = BankWithIndex(
|
self.cache[memory_bank.identifier] = BankWithIndex(
|
||||||
bank=memory_bank,
|
memory_bank,
|
||||||
index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
|
WeaviateIndex(client=client, collection_name=memory_bank.identifier),
|
||||||
|
self.inference_api,
|
||||||
)
|
)
|
||||||
self.cache[memory_bank.identifier] = index
|
|
||||||
|
|
||||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||||
if bank_id in self.cache:
|
if bank_id in self.cache:
|
||||||
|
@ -156,6 +166,7 @@ class WeaviateMemoryAdapter(
|
||||||
index = BankWithIndex(
|
index = BankWithIndex(
|
||||||
bank=bank,
|
bank=bank,
|
||||||
index=WeaviateIndex(client=client, collection_name=bank_id),
|
index=WeaviateIndex(client=client, collection_name=bank_id),
|
||||||
|
inference_api=self.inference_api,
|
||||||
)
|
)
|
||||||
self.cache[bank_id] = index
|
self.cache[bank_id] = index
|
||||||
return index
|
return index
|
||||||
|
|
|
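The new `WeaviateIndex.delete` removes stored chunks through the v4 client's `Filter` helper. A hedged sketch of the same call in isolation, assuming an already-connected Weaviate v4 client (`delete_chunks` and the collection name are illustrative, not llama-stack names):

from typing import List

import weaviate
from weaviate.classes.query import Filter


def delete_chunks(
    client: weaviate.WeaviateClient, collection_name: str, chunk_ids: List[str]
) -> None:
    # Remove every stored object whose "id" property matches one of chunk_ids.
    collection = client.collections.get(collection_name)
    collection.data.delete_many(
        where=Filter.by_property("id").contains_any(chunk_ids)
    )


# Usage sketch (assumes a locally running Weaviate instance):
# client = weaviate.connect_to_local()
# delete_chunks(client, "memory_bank_chunks", ["chunk-1", "chunk-2"])
# client.close()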
@@ -18,6 +18,12 @@ def pytest_addoption(parser):
         default=None,
         help="Specify the inference model to use for testing",
     )
+    parser.addoption(
+        "--embedding-model",
+        action="store",
+        default=None,
+        help="Specify the embedding model to use for testing",
+    )
 
 
 def pytest_configure(config):
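A fixture can then consume the new `--embedding-model` flag the usual pytest way. A minimal sketch, assuming a session-scoped fixture and a small default model (both assumptions, not taken from the repo):

import pytest


@pytest.fixture(scope="session")
def embedding_model(request) -> str:
    # Falls back to a small sentence-transformers model when the flag is omitted.
    return request.config.getoption("--embedding-model") or "all-MiniLM-L6-v2"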
@@ -9,9 +9,9 @@ import os
 import pytest
 import pytest_asyncio
 
-from llama_stack.apis.models import ModelInput
+from llama_stack.apis.models import ModelInput, ModelType
 
 from llama_stack.distribution.datatypes import Api, Provider
 
 from llama_stack.providers.inline.inference.meta_reference import (
     MetaReferenceInferenceConfig,
 )

@@ -47,6 +47,9 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
     inference_model = (
         [inference_model] if isinstance(inference_model, str) else inference_model
     )
+    # If embedding dimension is set, use the 8B model for testing
+    if os.getenv("EMBEDDING_DIMENSION"):
+        inference_model = ["meta-llama/Llama-3.1-8B-Instruct"]
 
     return ProviderFixture(
         providers=[

@@ -85,7 +88,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
     inference_model = (
         [inference_model] if isinstance(inference_model, str) else inference_model
     )
-    if "Llama3.1-8B-Instruct" in inference_model:
+    if inference_model and "Llama3.1-8B-Instruct" in inference_model:
         pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
 
     return ProviderFixture(

@@ -232,11 +235,23 @@ INFERENCE_FIXTURES = [
 async def inference_stack(request, inference_model):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
+    model_type = ModelType.llm
+    metadata = {}
+    if os.getenv("EMBEDDING_DIMENSION"):
+        model_type = ModelType.embedding_model
+        metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")
+
     test_stack = await construct_stack_for_test(
         [Api.inference],
         {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
-        models=[ModelInput(model_id=inference_model)],
+        models=[
+            ModelInput(
+                model_id=inference_model,
+                model_type=model_type,
+                metadata=metadata,
+            )
+        ],
     )
 
     return test_stack.impls[Api.inference], test_stack.impls[Api.models]
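When `EMBEDDING_DIMENSION` is exported, the fixture above registers the test model as an embedding model rather than an LLM. Roughly, the resulting `ModelInput` looks like the following sketch (the model id and dimension are illustrative values, not defaults from the repo):

import os

from llama_stack.apis.models import ModelInput, ModelType

os.environ.setdefault("EMBEDDING_DIMENSION", "384")

embedding_model_input = ModelInput(
    model_id="all-MiniLM-L6-v2",           # illustrative embedding model id
    model_type=ModelType.embedding_model,  # instead of the default ModelType.llm
    metadata={"embedding_dimension": os.environ["EMBEDDING_DIMENSION"]},
)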
llama_stack/providers/tests/inference/test_embeddings.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from llama_stack.apis.inference import EmbeddingsResponse, ModelType
+
+# How to run this test:
+# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py
+
+
+class TestEmbeddings:
+    @pytest.mark.asyncio
+    async def test_embeddings(self, inference_model, inference_stack):
+        inference_impl, models_impl = inference_stack
+        model = await models_impl.get_model(inference_model)
+
+        if model.model_type != ModelType.embedding_model:
+            pytest.skip("This test is only applicable for embedding models")
+
+        response = await inference_impl.embeddings(
+            model_id=inference_model,
+            contents=["Hello, world!"],
+        )
+        assert isinstance(response, EmbeddingsResponse)
+        assert len(response.embeddings) > 0
+        assert all(isinstance(embedding, list) for embedding in response.embeddings)
+        assert all(
+            isinstance(value, float)
+            for embedding in response.embeddings
+            for value in embedding
+        )
+
+    @pytest.mark.asyncio
+    async def test_batch_embeddings(self, inference_model, inference_stack):
+        inference_impl, models_impl = inference_stack
+        model = await models_impl.get_model(inference_model)
+
+        if model.model_type != ModelType.embedding_model:
+            pytest.skip("This test is only applicable for embedding models")
+
+        texts = ["Hello, world!", "This is a test", "Testing embeddings"]
+
+        response = await inference_impl.embeddings(
+            model_id=inference_model,
+            contents=texts,
+        )
+
+        assert isinstance(response, EmbeddingsResponse)
+        assert len(response.embeddings) == len(texts)
+        assert all(isinstance(embedding, list) for embedding in response.embeddings)
+        assert all(
+            isinstance(value, float)
+            for embedding in response.embeddings
+            for value in embedding
+        )
+
+        embedding_dim = len(response.embeddings[0])
+        assert all(len(embedding) == embedding_dim for embedding in response.embeddings)
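These tests skip unless the registered model is an embedding model, which the inference fixtures arrange when `EMBEDDING_DIMENSION` is set. Once the response shape asserted above holds, downstream code can treat `response.embeddings` as a plain matrix; a small sketch with fake data (NumPy only, nothing llama-stack specific):

import numpy as np

# Stand-in for EmbeddingsResponse.embeddings: a list of equal-length float lists.
embeddings = [[0.1, 0.3, 0.2], [0.4, 0.0, 0.9]]

matrix = np.array(embeddings, dtype=np.float32)      # shape: (n_texts, embedding_dim)
norms = np.linalg.norm(matrix, axis=1, keepdims=True)
unit = matrix / np.clip(norms, 1e-12, None)          # L2-normalize for cosine similarity
similarity = unit @ unit.T
print(similarity.shape)                              # (2, 2)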
@@ -6,9 +6,65 @@
 
 import pytest
 
+from ..conftest import get_provider_fixture_overrides
+
+from ..inference.fixtures import INFERENCE_FIXTURES
 from .fixtures import MEMORY_FIXTURES
 
 
+DEFAULT_PROVIDER_COMBINATIONS = [
+    pytest.param(
+        {
+            "inference": "meta_reference",
+            "memory": "faiss",
+        },
+        id="meta_reference",
+        marks=pytest.mark.meta_reference,
+    ),
+    pytest.param(
+        {
+            "inference": "ollama",
+            "memory": "pgvector",
+        },
+        id="ollama",
+        marks=pytest.mark.ollama,
+    ),
+    pytest.param(
+        {
+            "inference": "together",
+            "memory": "chroma",
+        },
+        id="chroma",
+        marks=pytest.mark.chroma,
+    ),
+    pytest.param(
+        {
+            "inference": "bedrock",
+            "memory": "qdrant",
+        },
+        id="qdrant",
+        marks=pytest.mark.qdrant,
+    ),
+    pytest.param(
+        {
+            "inference": "fireworks",
+            "memory": "weaviate",
+        },
+        id="weaviate",
+        marks=pytest.mark.weaviate,
+    ),
+]
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--inference-model",
+        action="store",
+        default=None,
+        help="Specify the inference model to use for testing",
+    )
+
+
 def pytest_configure(config):
     for fixture_name in MEMORY_FIXTURES:
         config.addinivalue_line(

@@ -18,12 +74,22 @@ def pytest_configure(config):
 
 
 def pytest_generate_tests(metafunc):
-    if "memory_stack" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "memory_stack",
-            [
-                pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
-                for fixture_name in MEMORY_FIXTURES
-            ],
-            indirect=True,
-        )
+    if "inference_model" in metafunc.fixturenames:
+        model = metafunc.config.getoption("--inference-model")
+        if not model:
+            raise ValueError(
+                "No inference model specified. Please provide a valid inference model."
+            )
+        params = [pytest.param(model, id="")]
+
+        metafunc.parametrize("inference_model", params, indirect=True)
+    if "memory_stack" in metafunc.fixturenames:
+        available_fixtures = {
+            "inference": INFERENCE_FIXTURES,
+            "memory": MEMORY_FIXTURES,
+        }
+        combinations = (
+            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            or DEFAULT_PROVIDER_COMBINATIONS
+        )
+        metafunc.parametrize("memory_stack", combinations, indirect=True)
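`memory_stack` is now parametrized with a dict of provider fixture names rather than a single memory fixture, so each inference/memory pairing becomes one test id carrying its own mark. A stripped-down sketch of that pytest mechanism (fixture and test names here are illustrative, not from the repo):

import pytest

COMBINATIONS = [
    pytest.param({"inference": "meta_reference", "memory": "faiss"}, id="meta_reference"),
    pytest.param({"inference": "ollama", "memory": "pgvector"}, id="ollama"),
]


def pytest_generate_tests(metafunc):
    if "stack" in metafunc.fixturenames:
        metafunc.parametrize("stack", COMBINATIONS, indirect=True)


@pytest.fixture
def stack(request):
    # Indirect parametrization: request.param is the dict selected above; a real
    # fixture would use it to look up provider fixtures by name.
    return request.param


def test_combination(stack):
    assert set(stack) == {"inference", "memory"}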
@@ -10,6 +10,8 @@ import tempfile
 import pytest
 import pytest_asyncio
 
+from llama_stack.apis.inference import ModelInput, ModelType
+
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
 from llama_stack.providers.inline.memory.faiss import FaissImplConfig

@@ -105,14 +107,30 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
 
 
 @pytest_asyncio.fixture(scope="session")
-async def memory_stack(request):
-    fixture_name = request.param
-    fixture = request.getfixturevalue(f"memory_{fixture_name}")
+async def memory_stack(inference_model, request):
+    fixture_dict = request.param
+
+    providers = {}
+    provider_data = {}
+    for key in ["inference", "memory"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)
 
     test_stack = await construct_stack_for_test(
-        [Api.memory],
-        {"memory": fixture.providers},
-        fixture.provider_data,
+        [Api.memory, Api.inference],
+        providers,
+        provider_data,
+        models=[
+            ModelInput(
+                model_id=inference_model,
+                model_type=ModelType.embedding_model,
+                metadata={
+                    "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
+                },
+            )
+        ],
     )
 
     return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]
@@ -45,12 +45,14 @@ def sample_documents():
     ]
 
 
-async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
+async def register_memory_bank(
+    banks_impl: MemoryBanks, inference_model: str
+) -> MemoryBank:
     bank_id = f"test_bank_{uuid.uuid4().hex}"
     return await banks_impl.register_memory_bank(
         memory_bank_id=bank_id,
         params=VectorMemoryBankParams(
-            embedding_model="all-MiniLM-L6-v2",
+            embedding_model=inference_model,
             chunk_size_in_tokens=512,
             overlap_size_in_tokens=64,
         ),

@@ -59,11 +61,11 @@ async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
 
 class TestMemory:
     @pytest.mark.asyncio
-    async def test_banks_list(self, memory_stack):
+    async def test_banks_list(self, memory_stack, inference_model):
         _, banks_impl = memory_stack
 
         # Register a test bank
-        registered_bank = await register_memory_bank(banks_impl)
+        registered_bank = await register_memory_bank(banks_impl, inference_model)
 
         try:
             # Verify our bank shows up in list

@@ -84,7 +86,7 @@ class TestMemory:
             )
 
     @pytest.mark.asyncio
-    async def test_banks_register(self, memory_stack):
+    async def test_banks_register(self, memory_stack, inference_model):
         _, banks_impl = memory_stack
 
         bank_id = f"test_bank_{uuid.uuid4().hex}"

@@ -94,7 +96,7 @@ class TestMemory:
             await banks_impl.register_memory_bank(
                 memory_bank_id=bank_id,
                 params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
+                    embedding_model=inference_model,
                     chunk_size_in_tokens=512,
                     overlap_size_in_tokens=64,
                 ),

@@ -109,7 +111,7 @@ class TestMemory:
             await banks_impl.register_memory_bank(
                 memory_bank_id=bank_id,
                 params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
+                    embedding_model=inference_model,
                     chunk_size_in_tokens=512,
                     overlap_size_in_tokens=64,
                 ),

@@ -126,13 +128,15 @@ class TestMemory:
             await banks_impl.unregister_memory_bank(bank_id)
 
     @pytest.mark.asyncio
-    async def test_query_documents(self, memory_stack, sample_documents):
+    async def test_query_documents(
+        self, memory_stack, inference_model, sample_documents
+    ):
         memory_impl, banks_impl = memory_stack
 
         with pytest.raises(ValueError):
             await memory_impl.insert_documents("test_bank", sample_documents)
 
-        registered_bank = await register_memory_bank(banks_impl)
+        registered_bank = await register_memory_bank(banks_impl, inference_model)
         await memory_impl.insert_documents(
             registered_bank.memory_bank_id, sample_documents
         )

@@ -165,13 +169,13 @@ class TestMemory:
 
         # Test case 5: Query with threshold on similarity score
         query5 = "quantum computing"  # Not directly related to any document
-        params5 = {"score_threshold": 0.2}
+        params5 = {"score_threshold": 0.01}
         response5 = await memory_impl.query_documents(
             registered_bank.memory_bank_id, query5, params5
         )
         assert_valid_response(response5)
         print("The scores are:", response5.scores)
-        assert all(score >= 0.2 for score in response5.scores)
+        assert all(score >= 0.01 for score in response5.scores)
 
 
 def assert_valid_response(response: QueryDocumentsResponse):
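Put together, the calls exercised above amount to a small end-to-end flow: register a vector bank against a registered embedding model, insert documents, then query. A hedged sketch built only from those calls (the import paths, the document fields, and the `memory_impl`/`banks_impl` handles are assumptions about an already-constructed stack):

import uuid

from llama_stack.apis.memory import MemoryBankDocument
from llama_stack.apis.memory_banks import VectorMemoryBankParams


async def index_and_query(memory_impl, banks_impl, embedding_model: str):
    bank_id = f"demo_bank_{uuid.uuid4().hex}"
    bank = await banks_impl.register_memory_bank(
        memory_bank_id=bank_id,
        params=VectorMemoryBankParams(
            embedding_model=embedding_model,  # a registered embedding model id
            chunk_size_in_tokens=512,
            overlap_size_in_tokens=64,
        ),
    )
    await memory_impl.insert_documents(
        bank.memory_bank_id,
        [
            MemoryBankDocument(
                document_id="doc-1",
                content="Llama Stack routes embedding generation through inference.",
                mime_type="text/plain",
                metadata={},
            )
        ],
    )
    return await memory_impl.query_documents(
        bank.memory_bank_id, "how are embeddings generated?", {"score_threshold": 0.01}
    )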
llama_stack/providers/utils/inference/embedding_mixin.py (new file, 47 lines)
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import logging
+from typing import List
+
+from llama_models.llama3.api.datatypes import InterleavedTextMedia
+
+from llama_stack.apis.inference.inference import EmbeddingsResponse, ModelStore
+
+EMBEDDING_MODELS = {}
+
+
+log = logging.getLogger(__name__)
+
+
+class SentenceTransformerEmbeddingMixin:
+    model_store: ModelStore
+
+    async def embeddings(
+        self,
+        model_id: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse:
+        model = await self.model_store.get_model(model_id)
+        embedding_model = self._load_sentence_transformer_model(
+            model.provider_resource_id
+        )
+        embeddings = embedding_model.encode(contents)
+        return EmbeddingsResponse(embeddings=embeddings)
+
+    def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
+        global EMBEDDING_MODELS
+
+        loaded_model = EMBEDDING_MODELS.get(model)
+        if loaded_model is not None:
+            return loaded_model
+
+        log.info(f"Loading sentence transformer for {model}...")
+        from sentence_transformers import SentenceTransformer
+
+        loaded_model = SentenceTransformer(model)
+        EMBEDDING_MODELS[model] = loaded_model
+        return loaded_model
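A provider can adopt this mixin so its `embeddings()` endpoint is served by sentence-transformers while the rest of its inference implementation stays untouched. A hedged sketch of such a class (illustrative only; the mixin merely expects a `model_store` whose `get_model` returns an object with `provider_resource_id` set to a sentence-transformers model name):

from llama_stack.providers.utils.inference.embedding_mixin import (
    SentenceTransformerEmbeddingMixin,
)


class ExampleEmbeddingProvider(SentenceTransformerEmbeddingMixin):
    # Illustrative only: supplies the model_store attribute the mixin expects.
    def __init__(self, model_store) -> None:
        self.model_store = model_store


# Usage sketch, assuming an embedding model registered with
# provider_resource_id="all-MiniLM-L6-v2":
#   provider = ExampleEmbeddingProvider(model_store=models_impl)
#   response = await provider.embeddings("all-MiniLM-L6-v2", ["hello world"])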
@@ -9,6 +9,7 @@ from typing import List, Optional
 
 from llama_models.sku_list import all_registered_models
 
+from llama_stack.apis.models.models import ModelType
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 
 from llama_stack.providers.utils.inference import (

@@ -77,7 +78,13 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         return None
 
     async def register_model(self, model: Model) -> Model:
-        provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
+        if model.model_type == ModelType.embedding_model:
+            # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
+            provider_resource_id = model.provider_resource_id
+        else:
+            provider_resource_id = self.get_provider_model_id(
+                model.provider_resource_id
+            )
         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else:
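Because embedding models skip the Llama alias mapping, they register directly under the provider's own model id. A sketch of what dynamic registration through the Models API might look like (the `models_impl` handle, the model id, and the dimension value are assumptions for illustration):

from llama_stack.apis.models import ModelType


async def register_embedding_model(models_impl):
    # Registered by provider model id; no mapping to a Llama model is attempted.
    return await models_impl.register_model(
        model_id="all-MiniLM-L6-v2",
        model_type=ModelType.embedding_model,
        metadata={"embedding_dimension": 384},
    )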
@@ -22,28 +22,10 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.providers.datatypes import Api
 
 log = logging.getLogger(__name__)
 
-ALL_MINILM_L6_V2_DIMENSION = 384
-
-EMBEDDING_MODELS = {}
-
-
-def get_embedding_model(model: str) -> "SentenceTransformer":
-    global EMBEDDING_MODELS
-
-    loaded_model = EMBEDDING_MODELS.get(model)
-    if loaded_model is not None:
-        return loaded_model
-
-    log.info(f"Loading sentence transformer for {model}...")
-    from sentence_transformers import SentenceTransformer
-
-    loaded_model = SentenceTransformer(model)
-    EMBEDDING_MODELS[model] = loaded_model
-    return loaded_model
-
 
 def parse_pdf(data: bytes) -> str:
     # For PDF and DOC/DOCX files, we can't reliably convert to string

@@ -166,12 +148,12 @@ class EmbeddingIndex(ABC):
 class BankWithIndex:
     bank: VectorMemoryBank
     index: EmbeddingIndex
+    inference_api: Api.inference
 
     async def insert_documents(
         self,
         documents: List[MemoryBankDocument],
     ) -> None:
-        model = get_embedding_model(self.bank.embedding_model)
        for doc in documents:
             content = await content_from_doc(doc)
             chunks = make_overlapped_chunks(

@@ -183,7 +165,10 @@ class BankWithIndex:
             )
             if not chunks:
                 continue
-            embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
+            embeddings_response = await self.inference_api.embeddings(
+                self.bank.embedding_model, [x.content for x in chunks]
+            )
+            embeddings = np.array(embeddings_response.embeddings)
 
             await self.index.add_chunks(chunks, embeddings)
 

@@ -208,6 +193,8 @@ class BankWithIndex:
         else:
             query_str = _process(query)
 
-        model = get_embedding_model(self.bank.embedding_model)
-        query_vector = model.encode([query_str])[0].astype(np.float32)
+        embeddings_response = await self.inference_api.embeddings(
+            self.bank.embedding_model, [query_str]
+        )
+        query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
         return await self.index.query(query_vector, k, score_threshold)
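With `get_embedding_model` removed, `BankWithIndex` is the only place that turns text into vectors, and it always goes through the injected inference API. A dependency-light sketch of that query path with stand-in classes (none of these stubs are llama-stack types):

import asyncio
from typing import List

import numpy as np


class StubInferenceAPI:
    async def embeddings(self, model: str, contents: List[str]):
        # Deterministic toy vectors so the sketch runs without any model download.
        vecs = [[float(len(c)), float(sum(map(ord, c)) % 97), 1.0] for c in contents]
        return type("EmbeddingsResponse", (), {"embeddings": vecs})()


class StubIndex:
    def __init__(self, vectors: np.ndarray, chunks: List[str]) -> None:
        self.vectors, self.chunks = vectors, chunks

    async def query(self, query_vector: np.ndarray, k: int, score_threshold: float):
        scores = self.vectors @ query_vector  # dot-product similarity
        top = np.argsort(-scores)[:k]
        return [(self.chunks[i], float(scores[i])) for i in top if scores[i] >= score_threshold]


async def main() -> None:
    api = StubInferenceAPI()
    docs = ["memory banks store chunks", "embeddings come from inference"]
    doc_vecs = np.array((await api.embeddings("toy-model", docs)).embeddings, dtype=np.float32)
    index = StubIndex(doc_vecs, docs)

    query_vec = np.array(
        (await api.embeddings("toy-model", ["where do embeddings come from?"])).embeddings[0],
        dtype=np.float32,
    )
    print(await index.query(query_vec, k=1, score_threshold=0.0))


asyncio.run(main())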