Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-14 13:02:36 +00:00)
chore: Updating how the default embedding model is set in the stack
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

# Conflicts:
#	.github/workflows/integration-vector-io-tests.yml
#	llama_stack/distributions/ci-tests/run.yaml
#	llama_stack/distributions/starter-gpu/run.yaml
#	llama_stack/distributions/starter/run.yaml
#	llama_stack/distributions/template.py
#	llama_stack/providers/utils/memory/openai_vector_store_mixin.py
parent cd152f4240
commit 24a1430c8b
32 changed files with 276 additions and 265 deletions
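
Every vector-io provider touched below repeats the same idiom: get_provider_impl gains an optional run_config: StackRunConfig | None parameter and forwards its vector_stores section to the adapter, so the default embedding model can come from the stack run config rather than from model metadata. A minimal sketch of that shared idiom, factored into a helper for illustration; the helper itself is not part of this commit, only StackRunConfig, its vector_stores field, and VectorStoresConfig appear in the diffs:

    from llama_stack.core.datatypes import StackRunConfig, VectorStoresConfig


    def resolve_vector_stores_config(run_config: StackRunConfig | None) -> VectorStoresConfig | None:
        # Returns None when the provider is constructed without a full stack run
        # config, which keeps the legacy two-argument call sites working unchanged.
        if run_config and run_config.vector_stores:
            return run_config.vector_stores
        return None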
sentence_transformers.py
@@ -59,7 +59,6 @@ class SentenceTransformersInferenceImpl(
             provider_id=self.__provider_id__,
             metadata={
                 "embedding_dimension": 768,
-                "default_configured": True,
             },
             model_type=ModelType.embedding,
         ),
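
With default_configured gone from the metadata, the registration keeps only the embedding dimension and model type; read alongside the run.yaml conflicts listed above, the default-embedding choice moves into the vector_stores section of the run config. A sketch of how the surviving registration reads; the Model identifier and the enclosing call are inferred from the context lines, not shown in the hunk:

    Model(
        identifier="nomic-ai/nomic-embed-text-v1.5",  # assumption: the identifier is not visible in the hunk
        provider_id=self.__provider_id__,
        metadata={
            "embedding_dimension": 768,  # unchanged; only the default flag is dropped
        },
        model_type=ModelType.embedding,
    )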
chroma/__init__.py
@@ -6,21 +6,29 @@

 from typing import Any

+from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api

 from .config import ChromaVectorIOConfig


-async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(
+    config: ChromaVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
+):
     from llama_stack.providers.remote.vector_io.chroma.chroma import (
         ChromaVectorIOAdapter,
     )

+    vector_stores_config = None
+    if run_config and run_config.vector_stores:
+        vector_stores_config = run_config.vector_stores
+
     impl = ChromaVectorIOAdapter(
         config,
         deps[Api.inference],
         deps[Api.models],
         deps.get(Api.files),
+        vector_stores_config,
     )
     await impl.initialize()
     return impl
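
Because run_config defaults to None, both call shapes stay valid after this change. A sketch, where config, deps, and run_cfg are placeholders for objects the stack would supply:

    impl = await get_provider_impl(config, deps)                      # legacy call: adapter gets vector_stores_config=None
    impl = await get_provider_impl(config, deps, run_config=run_cfg)  # stack-aware call: vector_stores section forwarded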
faiss/__init__.py
@@ -6,21 +6,29 @@

 from typing import Any

+from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api

 from .config import FaissVectorIOConfig


-async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(
+    config: FaissVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
+):
     from .faiss import FaissVectorIOAdapter

     assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

+    vector_stores_config = None
+    if run_config and run_config.vector_stores:
+        vector_stores_config = run_config.vector_stores
+
     impl = FaissVectorIOAdapter(
         config,
         deps[Api.inference],
         deps[Api.models],
         deps.get(Api.files),
+        vector_stores_config,
     )
     await impl.initialize()
     return impl
faiss/faiss.py
@@ -24,6 +24,7 @@ from llama_stack.apis.vector_io import (
     QueryChunksResponse,
     VectorIO,
 )
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import (
     HealthResponse,
@@ -206,11 +207,13 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         inference_api: Inference,
         models_api: Models,
         files_api: Files | None,
+        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
         self.models_api = models_api
+        self.vector_stores_config = vector_stores_config
         self.cache: dict[str, VectorDBWithIndex] = {}

     async def initialize(self) -> None:
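
The adapter only stores vector_stores_config here; consumption happens elsewhere (openai_vector_store_mixin.py appears in the conflicts list above but its hunks are not shown). A hypothetical reader of the stored config, with default_embedding_model as an assumed field name:

    def _default_embedding_model(self) -> str | None:
        # Hypothetical sketch: return the stack-wide default embedding model, if configured.
        if self.vector_stores_config is not None:
            return getattr(self.vector_stores_config, "default_embedding_model", None)
        return None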
@ -6,19 +6,27 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import MilvusVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]):
|
||||
async def get_provider_impl(
|
||||
config: MilvusVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
|
||||
):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
|
||||
|
||||
vector_stores_config = None
|
||||
if run_config and run_config.vector_stores:
|
||||
vector_stores_config = run_config.vector_stores
|
||||
|
||||
impl = MilvusVectorIOAdapter(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
deps[Api.models],
|
||||
deps.get(Api.models),
|
||||
deps.get(Api.files),
|
||||
vector_stores_config,
|
||||
)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
qdrant/__init__.py
@@ -6,20 +6,28 @@

 from typing import Any

+from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api

 from .config import QdrantVectorIOConfig


-async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(
+    config: QdrantVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
+):
     from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter

+    vector_stores_config = None
+    if run_config and run_config.vector_stores:
+        vector_stores_config = run_config.vector_stores
+
     assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}"
     impl = QdrantVectorIOAdapter(
         config,
         deps[Api.inference],
         deps[Api.models],
         deps.get(Api.files),
+        vector_stores_config,
     )
     await impl.initialize()
     return impl
sqlite_vec/__init__.py
@@ -6,20 +6,28 @@

 from typing import Any

+from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.datatypes import Api

 from .config import SQLiteVectorIOConfig


-async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
+async def get_provider_impl(
+    config: SQLiteVectorIOConfig, deps: dict[Api, Any], run_config: StackRunConfig | None = None
+):
     from .sqlite_vec import SQLiteVecVectorIOAdapter

+    vector_stores_config = None
+    if run_config and run_config.vector_stores:
+        vector_stores_config = run_config.vector_stores
+
     assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
     impl = SQLiteVecVectorIOAdapter(
         config,
         deps[Api.inference],
         deps[Api.models],
         deps.get(Api.files),
+        vector_stores_config,
     )
     await impl.initialize()
     return impl
sqlite_vec/sqlite_vec.py
@@ -24,6 +24,7 @@ from llama_stack.apis.vector_io import (
     QueryChunksResponse,
     VectorIO,
 )
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
@@ -416,11 +417,13 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         inference_api: Inference,
         models_api: Models,
         files_api: Files | None,
+        vector_stores_config: VectorStoresConfig | None = None,
     ) -> None:
         super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
         self.models_api = models_api
+        self.vector_stores_config = vector_stores_config
         self.cache: dict[str, VectorDBWithIndex] = {}
         self.vector_db_store = None
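
A quick way to see the new plumbing end to end (illustrative; the argument objects are placeholders): construct the adapter directly with the positional order from the diff and confirm the config is stored.

    adapter = SQLiteVecVectorIOAdapter(config, inference_api, models_api, None, vector_stores_config)
    assert adapter.vector_stores_config is vector_stores_config  # stored by __init__, consumed downstream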