diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py index fb2528873..5ed177022 100644 --- a/llama_stack/templates/template.py +++ b/llama_stack/templates/template.py @@ -38,42 +38,45 @@ from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as ge def get_model_registry( available_models: dict[str, list[ProviderModelEntry]], ) -> tuple[list[ModelInput], bool]: - models = [] - - # check for conflicts in model ids - all_ids = set() - ids_conflict = False - - for _, entries in available_models.items(): - for entry in entries: - ids = [entry.provider_model_id] + entry.aliases - for model_id in ids: - if model_id in all_ids: - ids_conflict = True - rich.print( - f"[yellow]Model id {model_id} conflicts; all model ids will be prefixed with provider id[/yellow]" - ) - break - all_ids.update(ids) - if ids_conflict: - break - if ids_conflict: - break + # Flatten all entries with their IDs - O(n) where n is total number of model IDs + all_entries = [] + all_ids = [] for provider_id, entries in available_models.items(): for entry in entries: ids = [entry.provider_model_id] + entry.aliases for model_id in ids: - identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id - models.append( - ModelInput( - model_id=identifier, - provider_model_id=entry.provider_model_id, - provider_id=provider_id, - model_type=entry.model_type, - metadata=entry.metadata, - ) - ) + all_entries.append((provider_id, entry, model_id)) + all_ids.append(model_id) + + # Check for conflicts - O(n) + id_to_count = {} + for model_id in all_ids: + id_to_count[model_id] = id_to_count.get(model_id, 0) + 1 + + ids_conflict = any(count > 1 for count in id_to_count.values()) + + if ids_conflict: + # Find the first conflicting ID for the warning message + conflicting_id = next(model_id for model_id, count in id_to_count.items() if count > 1) + rich.print( + f"[yellow]Model id {conflicting_id} conflicts; all model ids will be prefixed with provider id[/yellow]" + ) + + # Build models list - O(n) + models = [] + for provider_id, entry, model_id in all_entries: + identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id + models.append( + ModelInput( + model_id=identifier, + provider_model_id=entry.provider_model_id, + provider_id=provider_id, + model_type=entry.model_type, + metadata=entry.metadata, + ) + ) + return models, ids_conflict