refactor: optimize get_model_registry() time complexity from O(n³) to O(n)

This commit is contained in:
r3v5 2025-07-17 14:52:04 +01:00
parent 51b179e1c5
commit 2eb10de6bc
No known key found for this signature in database
GPG key ID: 7758B9F272DE67D9

View file

@ -38,42 +38,45 @@ from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as ge
def _warn_on_model_id_conflict(
    available_models: dict[str, list[ProviderModelEntry]],
) -> bool:
    """Report whether any model id is claimed by more than one entry.

    Entries are scanned in mapping order. On the first id that was already
    registered by an earlier entry (from any provider), a warning is printed
    and the scan stops.

    Args:
        available_models: Mapping of provider id to that provider's entries.

    Returns:
        True if a conflicting model id was found, otherwise False.
    """
    seen_ids = set()
    for entries in available_models.values():
        for entry in entries:
            ids = [entry.provider_model_id, *entry.aliases]
            for model_id in ids:
                if model_id in seen_ids:
                    rich.print(
                        f"[yellow]Model id {model_id} conflicts; all model ids will be prefixed with provider id[/yellow]"
                    )
                    return True
            # Register the ids only after the whole entry has been checked, so
            # an alias that repeats its own entry's provider_model_id is not
            # reported as a conflict.
            seen_ids.update(ids)
    return False


def get_model_registry(
    available_models: dict[str, list[ProviderModelEntry]],
) -> tuple[list[ModelInput], bool]:
    """Build the list of model inputs for every provider entry.

    One ``ModelInput`` is produced for each entry's ``provider_model_id`` and
    for each of its aliases. If any model id is shared between entries, every
    model id is prefixed with ``"<provider_id>/"`` unless the provider id
    already occurs inside the model id.

    Args:
        available_models: Mapping of provider id to that provider's entries.

    Returns:
        A ``(models, ids_conflict)`` tuple: the generated model inputs and a
        flag telling whether a model id conflict was detected.
    """
    ids_conflict = _warn_on_model_id_conflict(available_models)

    models = []
    for provider_id, entries in available_models.items():
        for entry in entries:
            for model_id in [entry.provider_model_id, *entry.aliases]:
                # Short-circuit keeps the substring test from running when
                # there is no conflict, exactly as the one-line form did.
                needs_prefix = ids_conflict and provider_id not in model_id
                identifier = f"{provider_id}/{model_id}" if needs_prefix else model_id
                models.append(
                    ModelInput(
                        model_id=identifier,
                        provider_model_id=entry.provider_model_id,
                        provider_id=provider_id,
                        model_type=entry.model_type,
                        metadata=entry.metadata,
                    )
                )
    return models, ids_conflict