feat(registry): more flexible model lookup (#2859)

This PR updates model registration and lookup behavior to be slightly
more general / flexible. See
https://github.com/meta-llama/llama-stack/issues/2843 for more details.

Note that this change is backwards compatible given the design of the
`lookup_model()` method.

## Test Plan

Added unit tests
This commit is contained in:
Ashwin Bharambe 2025-07-22 15:22:48 -07:00 committed by GitHub
parent 9736f096f6
commit 3b83032555
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 265 additions and 75 deletions

View file

@ -57,7 +57,8 @@ class DatasetIORouter(DatasetIO):
logger.debug(
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
)
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
provider = await self.routing_table.get_provider_impl(dataset_id)
return await provider.iterrows(
dataset_id=dataset_id,
start_index=start_index,
limit=limit,
@ -65,7 +66,8 @@ class DatasetIORouter(DatasetIO):
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
provider = await self.routing_table.get_provider_impl(dataset_id)
return await provider.append_rows(
dataset_id=dataset_id,
rows=rows,
)

View file

@ -44,7 +44,8 @@ class ScoringRouter(Scoring):
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
provider = await self.routing_table.get_provider_impl(fn_identifier)
score_response = await provider.score_batch(
dataset_id=dataset_id,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
@ -66,7 +67,8 @@ class ScoringRouter(Scoring):
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
provider = await self.routing_table.get_provider_impl(fn_identifier)
score_response = await provider.score(
input_rows=input_rows,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
@ -97,7 +99,8 @@ class EvalRouter(Eval):
benchmark_config: BenchmarkConfig,
) -> Job:
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.run_eval(
benchmark_id=benchmark_id,
benchmark_config=benchmark_config,
)
@ -110,7 +113,8 @@ class EvalRouter(Eval):
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
@ -123,7 +127,8 @@ class EvalRouter(Eval):
job_id: str,
) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.job_status(benchmark_id, job_id)
async def job_cancel(
self,
@ -131,7 +136,8 @@ class EvalRouter(Eval):
job_id: str,
) -> None:
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
provider = await self.routing_table.get_provider_impl(benchmark_id)
await provider.job_cancel(
benchmark_id,
job_id,
)
@ -142,7 +148,8 @@ class EvalRouter(Eval):
job_id: str,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
provider = await self.routing_table.get_provider_impl(benchmark_id)
return await provider.job_result(
benchmark_id,
job_id,
)

View file

@ -231,7 +231,7 @@ class InferenceRouter(Inference):
logprobs=logprobs,
tool_config=tool_config,
)
provider = self.routing_table.get_provider_impl(model_id)
provider = await self.routing_table.get_provider_impl(model_id)
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
if stream:
@ -292,7 +292,7 @@ class InferenceRouter(Inference):
logger.debug(
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = self.routing_table.get_provider_impl(model_id)
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.batch_chat_completion(
model_id=model_id,
messages_batch=messages_batch,
@ -322,7 +322,7 @@ class InferenceRouter(Inference):
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding:
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
provider = self.routing_table.get_provider_impl(model_id)
provider = await self.routing_table.get_provider_impl(model_id)
params = dict(
model_id=model_id,
content=content,
@ -378,7 +378,7 @@ class InferenceRouter(Inference):
logger.debug(
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = self.routing_table.get_provider_impl(model_id)
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
async def embeddings(
@ -395,7 +395,8 @@ class InferenceRouter(Inference):
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.llm:
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
return await self.routing_table.get_provider_impl(model_id).embeddings(
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.embeddings(
model_id=model_id,
contents=contents,
text_truncation=text_truncation,
@ -458,7 +459,7 @@ class InferenceRouter(Inference):
suffix=suffix,
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_completion(**params)
async def openai_chat_completion(
@ -538,7 +539,7 @@ class InferenceRouter(Inference):
user=user,
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
if stream:
response_stream = await provider.openai_chat_completion(**params)
if self.store:
@ -575,7 +576,7 @@ class InferenceRouter(Inference):
user=user,
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_embeddings(**params)
async def list_chat_completions(

View file

@ -50,7 +50,8 @@ class SafetyRouter(Safety):
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
provider = await self.routing_table.get_provider_impl(shield_id)
return await provider.run_shield(
shield_id=shield_id,
messages=messages,
params=params,

View file

@ -41,9 +41,8 @@ class ToolRuntimeRouter(ToolRuntime):
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config
)
provider = await self.routing_table.get_provider_impl("knowledge_search")
return await provider.query(content, vector_db_ids, query_config)
async def insert(
self,
@ -54,9 +53,8 @@ class ToolRuntimeRouter(ToolRuntime):
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens
)
provider = await self.routing_table.get_provider_impl("insert_into_memory")
return await provider.insert(documents, vector_db_id, chunk_size_in_tokens)
def __init__(
self,
@ -80,7 +78,8 @@ class ToolRuntimeRouter(ToolRuntime):
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
provider = await self.routing_table.get_provider_impl(tool_name)
return await provider.invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
)

View file

@ -104,7 +104,8 @@ class VectorIORouter(VectorIO):
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds)
async def query_chunks(
self,
@ -113,7 +114,8 @@ class VectorIORouter(VectorIO):
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
provider = await self.routing_table.get_provider_impl(vector_db_id)
return await provider.query_chunks(vector_db_id, query, params)
# OpenAI Vector Stores API endpoints
async def openai_create_vector_store(
@ -146,7 +148,8 @@ class VectorIORouter(VectorIO):
provider_vector_db_id=vector_db_id,
vector_db_name=name,
)
return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
return await provider.openai_create_vector_store(
name=name,
file_ids=file_ids,
expires_after=expires_after,
@ -172,9 +175,8 @@ class VectorIORouter(VectorIO):
all_stores = []
for vector_db in vector_dbs:
try:
vector_store = await self.routing_table.get_provider_impl(
vector_db.identifier
).openai_retrieve_vector_store(vector_db.identifier)
provider = await self.routing_table.get_provider_impl(vector_db.identifier)
vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
all_stores.append(vector_store)
except Exception as e:
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")