fix: register provider model name and HF alias in run.yaml (#1304)
Each model known to the system has two identifiers:

- the `provider_resource_id` (what the provider calls it) -- e.g., `accounts/fireworks/models/llama-v3p1-8b-instruct`
- the `identifier` (`model_id`) under which it is registered and gets routed to the appropriate provider

We have so far used the HuggingFace repo alias as the standardized identifier you can use to refer to the model. So in the above example, we'd use `meta-llama/Llama-3.1-8B-Instruct` as the name under which it gets registered. This makes it convenient for users to refer to these models across providers.

However, we forgot to also register the _actual_ provider model ID. You should, of course, be able to route via the `provider_resource_id` as well. This change fixes that (somewhat grave) omission.

*Note*: this change is additive -- more aliases work now compared to before.

## Test Plan

Run the following for distro=(ollama fireworks together):

```
LLAMA_STACK_CONFIG=$distro \
  pytest -s -v tests/client-sdk/inference/test_text_inference.py \
  --inference-model=meta-llama/Llama-3.1-8B-Instruct --vision-inference-model=""
```
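To make the routing fix concrete, here is a minimal sketch -- not the actual Llama Stack implementation; the `ModelRegistry` class and its method names are hypothetical -- of a registry that maps both the HF alias and the provider's native model ID to the same provider entry:

```python
# Hypothetical sketch of dual-identifier routing; names are illustrative only.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelEntry:
    provider_id: str                # e.g. "fireworks"
    provider_model_id: str          # what the provider calls the model
    hf_alias: Optional[str] = None  # standardized HuggingFace repo name


class ModelRegistry:
    def __init__(self) -> None:
        self._routes: dict[str, ModelEntry] = {}

    def register(self, entry: ModelEntry) -> None:
        # Register the provider-native ID...
        self._routes[entry.provider_model_id] = entry
        # ...and the HF alias when one exists. Before this fix, only the
        # alias was registered, so routing by provider_model_id failed.
        if entry.hf_alias is not None:
            self._routes[entry.hf_alias] = entry

    def route(self, model_id: str) -> ModelEntry:
        return self._routes[model_id]


registry = ModelRegistry()
registry.register(
    ModelEntry(
        provider_id="fireworks",
        provider_model_id="accounts/fireworks/models/llama-v3p1-8b-instruct",
        hf_alias="meta-llama/Llama-3.1-8B-Instruct",
    )
)
# Both identifiers now resolve to the same entry.
assert registry.route("meta-llama/Llama-3.1-8B-Instruct") is registry.route(
    "accounts/fireworks/models/llama-v3p1-8b-instruct"
)
```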
parent c54164556a
commit 04de2f84e9

49 changed files with 637 additions and 217 deletions
```diff
@@ -13,7 +13,6 @@ from llama_stack.distribution.datatypes import (
     ShieldInput,
     ToolGroupInput,
 )
-from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
@@ -28,7 +27,7 @@ from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.remote.inference.groq.models import MODEL_ENTRIES as GROQ_MODEL_ENTRIES
 from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
 from llama_stack.providers.remote.inference.openai.models import MODEL_ENTRIES as OPENAI_MODEL_ENTRIES
-from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
+from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry


 def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
@@ -61,8 +60,7 @@ def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
         ),
     ]
     inference_providers = []
-    default_models = []
-    core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
+    available_models = {}
     for provider_id, model_entries, config in providers:
         inference_providers.append(
             Provider(
@@ -71,21 +69,12 @@ def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
                 config=config,
             )
         )
-        default_models.extend(
-            ModelInput(
-                model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
-                provider_model_id=m.provider_model_id,
-                provider_id=provider_id,
-                model_type=m.model_type,
-                metadata=m.metadata,
-            )
-            for m in model_entries
-        )
-    return inference_providers, default_models
+        available_models[provider_id] = model_entries
+    return inference_providers, available_models


 def get_distribution_template() -> DistributionTemplate:
-    inference_providers, default_models = get_inference_providers()
+    inference_providers, available_models = get_inference_providers()
     providers = {
         "inference": ([p.provider_type for p in inference_providers] + ["inline::sentence-transformers"]),
         "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
@@ -139,6 +128,7 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )

+    default_models = get_model_registry(available_models)
     return DistributionTemplate(
         name=name,
         distro_type="self_hosted",
@@ -146,7 +136,7 @@ def get_distribution_template() -> DistributionTemplate:
         container_image=None,
         template_path=None,
         providers=providers,
-        default_models=[],
+        available_models_by_provider=available_models,
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
```
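The diff delegates registration to a new `get_model_registry` helper imported from `llama_stack.templates.template`. Its implementation is not part of this hunk; a plausible sketch consistent with the commit message -- one `ModelInput` per known identifier, all pointing at the same `provider_model_id` -- is below. The `aliases` attribute on each entry is an assumption, not confirmed by the diff:

```python
# Sketch only: the real helper lives in llama_stack/templates/template.py and
# may differ. Assumes each model entry carries an `aliases` list (e.g. the
# HF repo name) alongside the provider-native model ID.
from llama_stack.distribution.datatypes import ModelInput  # mirrors the template's own imports


def get_model_registry(available_models):
    models = []
    for provider_id, entries in available_models.items():
        for entry in entries:
            # Emit one ModelInput per identifier, so a model can be routed
            # by its provider-native ID or by any of its aliases.
            for model_id in [entry.provider_model_id, *entry.aliases]:
                models.append(
                    ModelInput(
                        model_id=model_id,
                        provider_model_id=entry.provider_model_id,
                        provider_id=provider_id,
                        model_type=entry.model_type,
                        metadata=entry.metadata,
                    )
                )
    return models
```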