mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-29 03:14:19 +00:00
fix: register provider model name and HF alias in run.yaml (#1304)
Each model known to the system has two identifiers: - the `provider_resource_id` (what the provider calls it) -- e.g., `accounts/fireworks/models/llama-v3p1-8b-instruct` - the `identifier` (`model_id`) under which it is registered and gets routed to the appropriate provider. We have so far used the HuggingFace repo alias as the standardized identifier you can use to refer to the model. So in the above example, we'd use `meta-llama/Llama-3.1-8B-Instruct` as the name under which it gets registered. This makes it convenient for users to refer to these models across providers. However, we forgot to register the _actual_ provider model ID also. You should be able to route via `provider_resource_id` also, of course. This change fixes this (somewhat grave) omission. *Note*: this change is additive -- more aliases work now compared to before. ## Test Plan Run the following for distro=(ollama fireworks together) ``` LLAMA_STACK_CONFIG=$distro \ pytest -s -v tests/client-sdk/inference/test_text_inference.py \ --inference-model=meta-llama/Llama-3.1-8B-Instruct --vision-inference-model="" ```
This commit is contained in:
parent
c54164556a
commit
04de2f84e9
49 changed files with 637 additions and 217 deletions
|
@ -24,9 +24,33 @@ from llama_stack.distribution.datatypes import (
|
|||
)
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
def get_model_registry(available_models: Dict[str, List[ProviderModelEntry]]) -> List[ModelInput]:
|
||||
models = []
|
||||
for provider_id, entries in available_models.items():
|
||||
for entry in entries:
|
||||
ids = [entry.provider_model_id] + entry.aliases
|
||||
for model_id in ids:
|
||||
models.append(
|
||||
ModelInput(
|
||||
model_id=model_id,
|
||||
provider_model_id=entry.provider_model_id,
|
||||
provider_id=provider_id,
|
||||
model_type=entry.model_type,
|
||||
metadata=entry.metadata,
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
|
||||
class DefaultModel(BaseModel):
|
||||
model_id: str
|
||||
doc_string: str
|
||||
|
||||
|
||||
class RunConfigSettings(BaseModel):
|
||||
provider_overrides: Dict[str, List[Provider]] = Field(default_factory=dict)
|
||||
default_models: Optional[List[ModelInput]] = None
|
||||
|
@ -110,7 +134,7 @@ class DistributionTemplate(BaseModel):
|
|||
run_config_env_vars: Optional[Dict[str, Tuple[str, str]]] = None
|
||||
container_image: Optional[str] = None
|
||||
|
||||
default_models: Optional[List[ModelInput]] = None
|
||||
available_models_by_provider: Optional[Dict[str, List[ProviderModelEntry]]] = None
|
||||
|
||||
def build_config(self) -> BuildConfig:
|
||||
return BuildConfig(
|
||||
|
@ -148,13 +172,32 @@ class DistributionTemplate(BaseModel):
|
|||
autoescape=True,
|
||||
)
|
||||
template = env.from_string(template)
|
||||
|
||||
default_models = []
|
||||
if self.available_models_by_provider:
|
||||
has_multiple_providers = len(self.available_models_by_provider.keys()) > 1
|
||||
for provider_id, model_entries in self.available_models_by_provider.items():
|
||||
for model_entry in model_entries:
|
||||
doc_parts = []
|
||||
if model_entry.aliases:
|
||||
doc_parts.append(f"aliases: {', '.join(model_entry.aliases)}")
|
||||
if has_multiple_providers:
|
||||
doc_parts.append(f"provider: {provider_id}")
|
||||
|
||||
default_models.append(
|
||||
DefaultModel(
|
||||
model_id=model_entry.provider_model_id,
|
||||
doc_string=f"({' -- '.join(doc_parts)})" if doc_parts else "",
|
||||
)
|
||||
)
|
||||
|
||||
return template.render(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
providers=self.providers,
|
||||
providers_table=providers_table,
|
||||
run_config_env_vars=self.run_config_env_vars,
|
||||
default_models=self.default_models,
|
||||
default_models=default_models,
|
||||
)
|
||||
|
||||
def save_distribution(self, yaml_output_dir: Path, doc_output_dir: Path) -> None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue