fix: list models only for active providers

There is a recurring error where we can retrieve a model when doing something like a chat completion, but then we hit issues when trying to associate that model with an active provider.

This is a common thing that happens when:
1. you run the stack with say remote::ollama
2. you register a model, say llama3.2:3b
3. you do some completions, etc.
4. you kill the server
5. you `unset OLLAMA_URL`
6. you re-start the stack
7. you do `llama-stack-client models list`

```
├───────────────┼──────────────────────────────────────────────────────────────────────────────────┼──────────────────────────────────────────────────────────────────────┼───────────────────────────────────────┼──────────────────────────┤
│ embedding     │ all-minilm                                                                       │ all-minilm:l6-v2                                                     │ {'embedding_dimension': 384.0,        │ ollama                   │
│               │                                                                                  │                                                                      │ 'context_length': 512.0}              │                          │
├───────────────┼──────────────────────────────────────────────────────────────────────────────────┼──────────────────────────────────────────────────────────────────────┼───────────────────────────────────────┼──────────────────────────┤
│ llm           │ llama3.2:3b                                                                      │ llama3.2:3b                                                          │                                       │ ollama                   │
├───────────────┼──────────────────────────────────────────────────────────────────────────────────┼──────────────────────────────────────────────────────────────────────┼───────────────────────────────────────┼──────────────────────────┤
│ embedding     │ ollama/all-minilm:l6-v2                                                          │ all-minilm:l6-v2                                                     │ {'embedding_dimension': 384.0,        │ ollama                   │
│               │                                                                                  │                                                                      │ 'context_length': 512.0}              │                          │
├───────────────┼──────────────────────────────────────────────────────────────────────────────────┼──────────────────────────────────────────────────────────────────────┼───────────────────────────────────────┼──────────────────────────┤
│ llm           │ ollama/llama3.2:3b                                                               │ llama3.2:3b                                                          │                                       │ ollama                   │
├───────────────┼──────────────────────────────────────────────────────────────────────────────────┼──────────────────────────────────────────────────────────────────────┼───────────────────────────────────────┼──────────────────────────┤

```

This shouldn't be happening: `ollama` isn't a running provider, and the only reason the model is popping up is that it's in the dist_registry (on disk).

While it's nice to have this static store — so that if I `export OLLAMA_URL=..` again, it can read from the store — it shouldn't _always_ be reading and returning these models from the store.

Now, if you run `llama-stack-client models list` with this change, `llama3.2:3b` no longer appears.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-08-14 08:08:39 -04:00
parent 61582f327c
commit 3f580503ef
9 changed files with 25 additions and 17 deletions

View file

@ -171,7 +171,7 @@ class VectorIORouter(VectorIO):
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
# Route to default provider for now - could aggregate from all providers in the future
# call retrieve on each vector dbs to get list of vector stores
vector_dbs = await self.routing_table.get_all_with_type("vector_db")
vector_dbs = await self.routing_table.get_all_with_type_filtered("vector_db")
all_stores = []
for vector_db in vector_dbs:
try:

View file

@ -19,7 +19,7 @@ logger = get_logger(name=__name__, category="core")
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
return ListBenchmarksResponse(data=await self.get_all_with_type_filtered("benchmark"))
async def get_benchmark(self, benchmark_id: str) -> Benchmark:
benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)

View file

@ -232,15 +232,21 @@ class CommonRoutingTableImpl(RoutingTable):
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
filtered_objs = [obj for obj in objs if obj.type == type]
return [obj for obj in objs if obj.type == type]
async def get_all_with_type_filtered(self, type: str) -> list[RoutableObjectWithProvider]:
all_objs = await self.get_all_with_type(type=type)
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [
obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
if all_objs:
all_objs = [
obj
for obj in all_objs
if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
and obj.provider_id in self.impls_by_provider_id
]
return filtered_objs
return all_objs
async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
@ -257,7 +263,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
)
# if not found, this means model_id is an unscoped provider_model_id, we need
# to iterate (given a lack of an efficient index on the KVStore)
models = await routing_table.get_all_with_type("model")
models = await routing_table.get_all_with_type_filtered("model")
matching_models = [m for m in models if m.provider_resource_id == model_id]
if len(matching_models) == 0:
raise ModelNotFoundError(model_id)

View file

@ -31,7 +31,7 @@ logger = get_logger(name=__name__, category="core")
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
return ListDatasetsResponse(data=await self.get_all_with_type_filtered(ResourceType.dataset.value))
async def get_dataset(self, dataset_id: str) -> Dataset:
dataset = await self.get_object_by_identifier("dataset", dataset_id)

View file

@ -43,10 +43,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
await self.update_registered_models(provider_id, models)
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
return ListModelsResponse(data=await self.get_all_with_type_filtered("model"))
async def openai_list_models(self) -> OpenAIListModelsResponse:
models = await self.get_all_with_type("model")
models = await self.get_all_with_type_filtered("model")
openai_models = [
OpenAIModel(
id=model.identifier,
@ -122,7 +122,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_id: str,
models: list[Model],
) -> None:
existing_models = await self.get_all_with_type("model")
existing_models = await self.get_all_with_type_filtered("model")
# we may have an alias for the model registered by the user (or during initialization
# from run.yaml) that we need to keep track of

View file

@ -24,7 +24,9 @@ logger = get_logger(name=__name__, category="core")
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
return ListScoringFunctionsResponse(
data=await self.get_all_with_type_filtered(ResourceType.scoring_function.value)
)
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)

View file

@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="core")
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
return ListShieldsResponse(data=await self.get_all_with_type_filtered(ResourceType.shield.value))
async def get_shield(self, identifier: str) -> Shield:
shield = await self.get_object_by_identifier("shield", identifier)

View file

@ -49,7 +49,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
toolgroup_id = group_id
toolgroups = [await self.get_tool_group(toolgroup_id)]
else:
toolgroups = await self.get_all_with_type("tool_group")
toolgroups = await self.get_all_with_type_filtered("tool_group")
all_tools = []
for toolgroup in toolgroups:
@ -83,7 +83,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
return ListToolGroupsResponse(data=await self.get_all_with_type_filtered("tool_group"))
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)

View file

@ -35,7 +35,7 @@ logger = get_logger(name=__name__, category="core")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
return ListVectorDBsResponse(data=await self.get_all_with_type_filtered("vector_db"))
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)