mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-26 22:19:49 +00:00
Merge 2eb10de6bc
into cbe89d2bdd
This commit is contained in:
commit
ec31b6c179
1 changed files with 34 additions and 31 deletions
|
@ -38,42 +38,45 @@ from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as ge
|
||||||
def get_model_registry(
|
def get_model_registry(
|
||||||
available_models: dict[str, list[ProviderModelEntry]],
|
available_models: dict[str, list[ProviderModelEntry]],
|
||||||
) -> tuple[list[ModelInput], bool]:
|
) -> tuple[list[ModelInput], bool]:
|
||||||
models = []
|
# Flatten all entries with their IDs - O(n) where n is total number of model IDs
|
||||||
|
all_entries = []
|
||||||
# check for conflicts in model ids
|
all_ids = []
|
||||||
all_ids = set()
|
|
||||||
ids_conflict = False
|
|
||||||
|
|
||||||
for _, entries in available_models.items():
|
|
||||||
for entry in entries:
|
|
||||||
ids = [entry.provider_model_id] + entry.aliases
|
|
||||||
for model_id in ids:
|
|
||||||
if model_id in all_ids:
|
|
||||||
ids_conflict = True
|
|
||||||
rich.print(
|
|
||||||
f"[yellow]Model id {model_id} conflicts; all model ids will be prefixed with provider id[/yellow]"
|
|
||||||
)
|
|
||||||
break
|
|
||||||
all_ids.update(ids)
|
|
||||||
if ids_conflict:
|
|
||||||
break
|
|
||||||
if ids_conflict:
|
|
||||||
break
|
|
||||||
|
|
||||||
for provider_id, entries in available_models.items():
|
for provider_id, entries in available_models.items():
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
ids = [entry.provider_model_id] + entry.aliases
|
ids = [entry.provider_model_id] + entry.aliases
|
||||||
for model_id in ids:
|
for model_id in ids:
|
||||||
identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
|
all_entries.append((provider_id, entry, model_id))
|
||||||
models.append(
|
all_ids.append(model_id)
|
||||||
ModelInput(
|
|
||||||
model_id=identifier,
|
# Check for conflicts - O(n)
|
||||||
provider_model_id=entry.provider_model_id,
|
id_to_count = {}
|
||||||
provider_id=provider_id,
|
for model_id in all_ids:
|
||||||
model_type=entry.model_type,
|
id_to_count[model_id] = id_to_count.get(model_id, 0) + 1
|
||||||
metadata=entry.metadata,
|
|
||||||
)
|
ids_conflict = any(count > 1 for count in id_to_count.values())
|
||||||
)
|
|
||||||
|
if ids_conflict:
|
||||||
|
# Find the first conflicting ID for the warning message
|
||||||
|
conflicting_id = next(model_id for model_id, count in id_to_count.items() if count > 1)
|
||||||
|
rich.print(
|
||||||
|
f"[yellow]Model id {conflicting_id} conflicts; all model ids will be prefixed with provider id[/yellow]"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build models list - O(n)
|
||||||
|
models = []
|
||||||
|
for provider_id, entry, model_id in all_entries:
|
||||||
|
identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
|
||||||
|
models.append(
|
||||||
|
ModelInput(
|
||||||
|
model_id=identifier,
|
||||||
|
provider_model_id=entry.provider_model_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
model_type=entry.model_type,
|
||||||
|
metadata=entry.metadata,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return models, ids_conflict
|
return models, ids_conflict
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue