Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-26 06:07:43 +00:00)
feat(registry): more flexible model lookup (#2859)
This PR updates model registration and lookup behavior to be slightly more general and flexible. See https://github.com/meta-llama/llama-stack/issues/2843 for more details. Note that this change is backwards compatible given the design of the `lookup_model()` method.

## Test Plan

Added unit tests
Parent: 9736f096f6
Commit: 3b83032555

15 changed files with 265 additions and 75 deletions (the diff excerpt below covers a subset of these files)
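
In short, `get_provider_impl()` becomes `async` across the routing layer, and model lookup gains a backwards-compatible fallback. Below is a minimal, self-contained sketch of the new resolution order; the `Registry` class and model names are illustrative only, and the real logic is `lookup_model()` in the diff that follows:

```python
from dataclasses import dataclass


@dataclass
class Model:
    identifier: str            # globally unique: an alias or "provider_id/provider_model_id"
    provider_id: str
    provider_resource_id: str  # the provider's own name for the model


class Registry:
    def __init__(self, models: list[Model]):
        self.by_identifier = {m.identifier: m for m in models}

    def lookup(self, model_id: str) -> Model:
        # 1) exact identifier match: an alias or a fully scoped name
        if model_id in self.by_identifier:
            return self.by_identifier[model_id]
        # 2) backwards-compatible fallback: scan for an unscoped provider_model_id
        matches = [m for m in self.by_identifier.values() if m.provider_resource_id == model_id]
        if not matches:
            raise ValueError(f"Model '{model_id}' not found")
        if len(matches) > 1:
            raise ValueError(f"Multiple providers found for '{model_id}'")
        return matches[0]


registry = Registry([Model("ollama/llama3.2:3b", "ollama", "llama3.2:3b")])
assert registry.lookup("ollama/llama3.2:3b").provider_id == "ollama"  # scoped name
assert registry.lookup("llama3.2:3b").provider_id == "ollama"         # legacy fallback
```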
@@ -57,7 +57,8 @@ class DatasetIORouter(DatasetIO):
         logger.debug(
             f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
         )
-        return await self.routing_table.get_provider_impl(dataset_id).iterrows(
+        provider = await self.routing_table.get_provider_impl(dataset_id)
+        return await provider.iterrows(
             dataset_id=dataset_id,
             start_index=start_index,
             limit=limit,
@@ -65,7 +66,8 @@ class DatasetIORouter(DatasetIO):

     async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
         logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
-        return await self.routing_table.get_provider_impl(dataset_id).append_rows(
+        provider = await self.routing_table.get_provider_impl(dataset_id)
+        return await provider.append_rows(
             dataset_id=dataset_id,
             rows=rows,
         )
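
Every router file in this diff repeats the same mechanical change, because `CommonRoutingTableImpl.get_provider_impl()` (further down) becomes a coroutine: the chained call `...get_provider_impl(key).method(...)` is split so the provider is awaited first. A toy sketch of the pattern, with hypothetical names:

```python
import asyncio


class RoutingTable:
    """Stand-in for the real routing table; the dict lookup mimics an async registry read."""

    async def get_provider_impl(self, routing_key: str):
        await asyncio.sleep(0)  # pretend we hit the KVStore
        return {"echo": lambda x: x}[routing_key]


async def main() -> None:
    table = RoutingTable()
    # The old style, table.get_provider_impl("echo")("hi"), no longer works:
    # get_provider_impl() now returns a coroutine, so await the provider first.
    provider = await table.get_provider_impl("echo")
    print(provider("hi"))  # -> hi


asyncio.run(main())
```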
@ -44,7 +44,8 @@ class ScoringRouter(Scoring):
|
|||
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
|
||||
res = {}
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
|
||||
provider = await self.routing_table.get_provider_impl(fn_identifier)
|
||||
score_response = await provider.score_batch(
|
||||
dataset_id=dataset_id,
|
||||
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
|
||||
)
|
||||
|
@ -66,7 +67,8 @@ class ScoringRouter(Scoring):
|
|||
res = {}
|
||||
# look up and map each scoring function to its provider impl
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
|
||||
provider = await self.routing_table.get_provider_impl(fn_identifier)
|
||||
score_response = await provider.score(
|
||||
input_rows=input_rows,
|
||||
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
|
||||
)
|
||||
|
@ -97,7 +99,8 @@ class EvalRouter(Eval):
|
|||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
return await provider.run_eval(
|
||||
benchmark_id=benchmark_id,
|
||||
benchmark_config=benchmark_config,
|
||||
)
|
||||
|
@ -110,7 +113,8 @@ class EvalRouter(Eval):
|
|||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
return await provider.evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
input_rows=input_rows,
|
||||
scoring_functions=scoring_functions,
|
||||
|
@ -123,7 +127,8 @@ class EvalRouter(Eval):
|
|||
job_id: str,
|
||||
) -> Job:
|
||||
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
return await provider.job_status(benchmark_id, job_id)
|
||||
|
||||
async def job_cancel(
|
||||
self,
|
||||
|
@ -131,7 +136,8 @@ class EvalRouter(Eval):
|
|||
job_id: str,
|
||||
) -> None:
|
||||
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
|
||||
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
await provider.job_cancel(
|
||||
benchmark_id,
|
||||
job_id,
|
||||
)
|
||||
|
@ -142,7 +148,8 @@ class EvalRouter(Eval):
|
|||
job_id: str,
|
||||
) -> EvaluateResponse:
|
||||
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
return await provider.job_result(
|
||||
benchmark_id,
|
||||
job_id,
|
||||
)
|
||||
|
|
|
@ -231,7 +231,7 @@ class InferenceRouter(Inference):
|
|||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
|
||||
|
||||
if stream:
|
||||
|
@ -292,7 +292,7 @@ class InferenceRouter(Inference):
|
|||
logger.debug(
|
||||
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.batch_chat_completion(
|
||||
model_id=model_id,
|
||||
messages_batch=messages_batch,
|
||||
|
@ -322,7 +322,7 @@ class InferenceRouter(Inference):
|
|||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
content=content,
|
||||
|
@ -378,7 +378,7 @@ class InferenceRouter(Inference):
|
|||
logger.debug(
|
||||
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
|
||||
|
||||
async def embeddings(
|
||||
|
@ -395,7 +395,8 @@ class InferenceRouter(Inference):
|
|||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.llm:
|
||||
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
|
||||
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.embeddings(
|
||||
model_id=model_id,
|
||||
contents=contents,
|
||||
text_truncation=text_truncation,
|
||||
|
@ -458,7 +459,7 @@ class InferenceRouter(Inference):
|
|||
suffix=suffix,
|
||||
)
|
||||
|
||||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.openai_completion(**params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
|
@ -538,7 +539,7 @@ class InferenceRouter(Inference):
|
|||
user=user,
|
||||
)
|
||||
|
||||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
if stream:
|
||||
response_stream = await provider.openai_chat_completion(**params)
|
||||
if self.store:
|
||||
|
@ -575,7 +576,7 @@ class InferenceRouter(Inference):
|
|||
user=user,
|
||||
)
|
||||
|
||||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.openai_embeddings(**params)
|
||||
|
||||
async def list_chat_completions(
|
||||
|
|
|
@ -50,7 +50,8 @@ class SafetyRouter(Safety):
|
|||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||
return await self.routing_table.get_provider_impl(shield_id).run_shield(
|
||||
provider = await self.routing_table.get_provider_impl(shield_id)
|
||||
return await provider.run_shield(
|
||||
shield_id=shield_id,
|
||||
messages=messages,
|
||||
params=params,
|
||||
|
|
|
@ -41,9 +41,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
query_config: RAGQueryConfig | None = None,
|
||||
) -> RAGQueryResult:
|
||||
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
||||
return await self.routing_table.get_provider_impl("knowledge_search").query(
|
||||
content, vector_db_ids, query_config
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl("knowledge_search")
|
||||
return await provider.query(content, vector_db_ids, query_config)
|
||||
|
||||
async def insert(
|
||||
self,
|
||||
|
@ -54,9 +53,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
logger.debug(
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||
)
|
||||
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
||||
documents, vector_db_id, chunk_size_in_tokens
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl("insert_into_memory")
|
||||
return await provider.insert(documents, vector_db_id, chunk_size_in_tokens)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -80,7 +78,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
|
||||
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
|
||||
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
|
||||
provider = await self.routing_table.get_provider_impl(tool_name)
|
||||
return await provider.invoke_tool(
|
||||
tool_name=tool_name,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
|
|
@ -104,7 +104,8 @@ class VectorIORouter(VectorIO):
|
|||
logger.debug(
|
||||
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||
provider = await self.routing_table.get_provider_impl(vector_db_id)
|
||||
return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
|
@ -113,7 +114,8 @@ class VectorIORouter(VectorIO):
|
|||
params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
|
||||
provider = await self.routing_table.get_provider_impl(vector_db_id)
|
||||
return await provider.query_chunks(vector_db_id, query, params)
|
||||
|
||||
# OpenAI Vector Stores API endpoints
|
||||
async def openai_create_vector_store(
|
||||
|
@ -146,7 +148,8 @@ class VectorIORouter(VectorIO):
|
|||
provider_vector_db_id=vector_db_id,
|
||||
vector_db_name=name,
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
|
||||
provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
|
||||
return await provider.openai_create_vector_store(
|
||||
name=name,
|
||||
file_ids=file_ids,
|
||||
expires_after=expires_after,
|
||||
|
@ -172,9 +175,8 @@ class VectorIORouter(VectorIO):
|
|||
all_stores = []
|
||||
for vector_db in vector_dbs:
|
||||
try:
|
||||
vector_store = await self.routing_table.get_provider_impl(
|
||||
vector_db.identifier
|
||||
).openai_retrieve_vector_store(vector_db.identifier)
|
||||
provider = await self.routing_table.get_provider_impl(vector_db.identifier)
|
||||
vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
|
||||
all_stores.append(vector_store)
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
|
@ -116,7 +117,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
for p in self.impls_by_provider_id.values():
|
||||
await p.shutdown()
|
||||
|
||||
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||
async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||
from .benchmarks import BenchmarksRoutingTable
|
||||
from .datasets import DatasetsRoutingTable
|
||||
from .models import ModelsRoutingTable
|
||||
|
@ -235,3 +236,28 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
]
|
||||
|
||||
return filtered_objs
|
||||
|
||||
|
||||
async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
|
||||
# first try to get the model by identifier
|
||||
# this works if model_id is an alias or is of the form provider_id/provider_model_id
|
||||
model = await routing_table.get_object_by_identifier("model", model_id)
|
||||
if model is not None:
|
||||
return model
|
||||
|
||||
logger.warning(
|
||||
f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
|
||||
"searching in all providers. This is only for backwards compatibility and will stop working "
|
||||
"soon. Migrate your calls to use fully scoped `provider_id/model_id` names."
|
||||
)
|
||||
# if not found, this means model_id is an unscoped provider_model_id, we need
|
||||
# to iterate (given a lack of an efficient index on the KVStore)
|
||||
models = await routing_table.get_all_with_type("model")
|
||||
matching_models = [m for m in models if m.provider_resource_id == model_id]
|
||||
if len(matching_models) == 0:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
|
||||
if len(matching_models) > 1:
|
||||
raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
|
||||
|
||||
return matching_models[0]
|
||||
|
|
|
@ -13,7 +13,7 @@ from llama_stack.distribution.datatypes import (
|
|||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
@ -36,10 +36,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
return OpenAIListModelsResponse(data=openai_models)
|
||||
|
||||
async def get_model(self, model_id: str) -> Model:
|
||||
model = await self.get_object_by_identifier("model", model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
return model
|
||||
return await lookup_model(self, model_id)
|
||||
|
||||
async def get_provider_impl(self, model_id: str) -> Any:
|
||||
model = await lookup_model(self, model_id)
|
||||
return self.impls_by_provider_id[model.provider_id]
|
||||
|
||||
async def register_model(
|
||||
self,
|
||||
|
@ -49,24 +50,33 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
metadata: dict[str, Any] | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
) -> Model:
|
||||
if provider_model_id is None:
|
||||
provider_model_id = model_id
|
||||
if provider_id is None:
|
||||
# If provider_id not specified, use the only provider if it supports this model
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
|
||||
f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}.\n\n"
|
||||
"Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'."
|
||||
)
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
if model_type is None:
|
||||
model_type = ModelType.llm
|
||||
|
||||
provider_model_id = provider_model_id or model_id
|
||||
metadata = metadata or {}
|
||||
model_type = model_type or ModelType.llm
|
||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||
|
||||
# an identifier different than provider_model_id implies it is an alias, so that
|
||||
# becomes the globally unique identifier. otherwise provider_model_ids can conflict,
|
||||
# so as a general rule we must use the provider_id to disambiguate.
|
||||
|
||||
if model_id != provider_model_id:
|
||||
identifier = model_id
|
||||
else:
|
||||
identifier = f"{provider_id}/{provider_model_id}"
|
||||
|
||||
model = ModelWithOwner(
|
||||
identifier=model_id,
|
||||
identifier=identifier,
|
||||
provider_resource_id=provider_model_id,
|
||||
provider_id=provider_id,
|
||||
metadata=metadata,
|
||||
|
|
|
@ -30,7 +30,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
@@ -30,7 +30,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
     tool_to_toolgroup: dict[str, str] = {}

     # overridden
-    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
+    async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
         # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
         # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?

@@ -40,7 +40,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):

         if routing_key in self.tool_to_toolgroup:
             routing_key = self.tool_to_toolgroup[routing_key]
-        return super().get_provider_impl(routing_key, provider_id)
+        return await super().get_provider_impl(routing_key, provider_id)

     async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
         if toolgroup_id:
@@ -59,7 +59,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         return ListToolsResponse(data=all_tools)

     async def _index_tools(self, toolgroup: ToolGroup):
-        provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
+        provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
         tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)

         # TODO: kill this Tool vs ToolDef distinction
@@ -27,7 +27,7 @@ from llama_stack.distribution.datatypes import (
 )
 from llama_stack.log import get_logger

-from .common import CommonRoutingTableImpl
+from .common import CommonRoutingTableImpl, lookup_model

 logger = get_logger(name=__name__, category="core")

@@ -51,8 +51,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         provider_vector_db_id: str | None = None,
         vector_db_name: str | None = None,
     ) -> VectorDB:
-        if provider_vector_db_id is None:
-            provider_vector_db_id = vector_db_id
+        provider_vector_db_id = provider_vector_db_id or vector_db_id
         if provider_id is None:
             if len(self.impls_by_provider_id) > 0:
                 provider_id = list(self.impls_by_provider_id.keys())[0]
@@ -62,7 +61,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
                 )
             else:
                 raise ValueError("No provider available. Please configure a vector_io provider.")
-        model = await self.get_object_by_identifier("model", embedding_model)
+        model = await lookup_model(self, embedding_model)
         if model is None:
             raise ValueError(f"Model {embedding_model} not found")
         if model.model_type != ModelType.embedding:
@@ -93,7 +92,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         vector_store_id: str,
     ) -> VectorStoreObject:
         await self.assert_action_allowed("read", "vector_db", vector_store_id)
-        return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store(vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_retrieve_vector_store(vector_store_id)

     async def openai_update_vector_store(
         self,
@@ -103,7 +103,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         metadata: dict[str, Any] | None = None,
     ) -> VectorStoreObject:
         await self.assert_action_allowed("update", "vector_db", vector_store_id)
-        return await self.get_provider_impl(vector_store_id).openai_update_vector_store(
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_update_vector_store(
            vector_store_id=vector_store_id,
            name=name,
            expires_after=expires_after,
@@ -115,7 +116,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         vector_store_id: str,
     ) -> VectorStoreDeleteResponse:
         await self.assert_action_allowed("delete", "vector_db", vector_store_id)
-        result = await self.get_provider_impl(vector_store_id).openai_delete_vector_store(vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        result = await provider.openai_delete_vector_store(vector_store_id)
         await self.unregister_vector_db(vector_store_id)
         return result

@@ -130,7 +132,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         search_mode: str | None = "vector",
     ) -> VectorStoreSearchResponsePage:
         await self.assert_action_allowed("read", "vector_db", vector_store_id)
-        return await self.get_provider_impl(vector_store_id).openai_search_vector_store(
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_search_vector_store(
             vector_store_id=vector_store_id,
             query=query,
             filters=filters,
@@ -148,7 +151,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         chunking_strategy: VectorStoreChunkingStrategy | None = None,
     ) -> VectorStoreFileObject:
         await self.assert_action_allowed("update", "vector_db", vector_store_id)
-        return await self.get_provider_impl(vector_store_id).openai_attach_file_to_vector_store(
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_attach_file_to_vector_store(
             vector_store_id=vector_store_id,
             file_id=file_id,
             attributes=attributes,
@@ -165,7 +169,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         filter: VectorStoreFileStatus | None = None,
     ) -> list[VectorStoreFileObject]:
         await self.assert_action_allowed("read", "vector_db", vector_store_id)
-        return await self.get_provider_impl(vector_store_id).openai_list_files_in_vector_store(
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_list_files_in_vector_store(
             vector_store_id=vector_store_id,
             limit=limit,
             order=order,
@@ -180,7 +185,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         file_id: str,
     ) -> VectorStoreFileObject:
         await self.assert_action_allowed("read", "vector_db", vector_store_id)
-        return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file(
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_retrieve_vector_store_file(
             vector_store_id=vector_store_id,
             file_id=file_id,
         )
@@ -191,7 +197,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         file_id: str,
     ) -> VectorStoreFileContentsResponse:
         await self.assert_action_allowed("read", "vector_db", vector_store_id)
-        return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file_contents(
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_retrieve_vector_store_file_contents(
             vector_store_id=vector_store_id,
             file_id=file_id,
         )
@@ -203,7 +210,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         attributes: dict[str, Any],
     ) -> VectorStoreFileObject:
         await self.assert_action_allowed("update", "vector_db", vector_store_id)
-        return await self.get_provider_impl(vector_store_id).openai_update_vector_store_file(
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_update_vector_store_file(
             vector_store_id=vector_store_id,
             file_id=file_id,
             attributes=attributes,
@@ -215,7 +223,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         file_id: str,
     ) -> VectorStoreFileDeleteResponse:
         await self.assert_action_allowed("delete", "vector_db", vector_store_id)
-        return await self.get_provider_impl(vector_store_id).openai_delete_vector_store_file(
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_delete_vector_store_file(
             vector_store_id=vector_store_id,
             file_id=file_id,
         )