mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 12:06:04 +00:00
more small cleanup
This commit is contained in:
parent
00c6bbffb7
commit
9c9f5f059a
2 changed files with 12 additions and 8 deletions
|
|
@ -93,7 +93,7 @@ async def get_auto_router_impl(
|
||||||
await inference_store.initialize()
|
await inference_store.initialize()
|
||||||
api_to_dep_impl["store"] = inference_store
|
api_to_dep_impl["store"] = inference_store
|
||||||
|
|
||||||
if api == Api.vector_io and run_config.vector_stores:
|
elif api == Api.vector_io:
|
||||||
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
|
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
|
||||||
|
|
||||||
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ from llama_stack.apis.vector_io import (
|
||||||
VectorStoreObject,
|
VectorStoreObject,
|
||||||
VectorStoreSearchResponsePage,
|
VectorStoreSearchResponsePage,
|
||||||
)
|
)
|
||||||
|
from llama_stack.core.datatypes import VectorStoresConfig
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||||
|
|
||||||
|
|
@ -43,7 +44,7 @@ class VectorIORouter(VectorIO):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: RoutingTable,
|
||||||
vector_stores_config=None,
|
vector_stores_config: VectorStoresConfig | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug("Initializing VectorIORouter")
|
logger.debug("Initializing VectorIORouter")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
|
|
@ -125,8 +126,11 @@ class VectorIORouter(VectorIO):
|
||||||
provider_id = extra.get("provider_id")
|
provider_id = extra.get("provider_id")
|
||||||
|
|
||||||
# Use default embedding model if not specified
|
# Use default embedding model if not specified
|
||||||
if embedding_model is None and self.vector_stores_config is not None:
|
if (
|
||||||
if self.vector_stores_config.default_embedding_model is not None:
|
embedding_model is None
|
||||||
|
and self.vector_stores_config
|
||||||
|
and self.vector_stores_config.default_embedding_model is not None
|
||||||
|
):
|
||||||
# Construct the full model ID with provider prefix
|
# Construct the full model ID with provider prefix
|
||||||
embedding_provider_id = self.vector_stores_config.default_embedding_model.provider_id
|
embedding_provider_id = self.vector_stores_config.default_embedding_model.provider_id
|
||||||
model_id = self.vector_stores_config.default_embedding_model.model_id
|
model_id = self.vector_stores_config.default_embedding_model.model_id
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue