mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Inference to use provider resource id to register and validate (#428)
This PR changes the way model id gets translated to the final model name that gets passed through the provider. Major changes include: 1) Providers are responsible for registering an object and as part of the registration returning the object with the correct provider specific name of the model provider_resource_id 2) To help with the common look ups different names a new ModelLookup class is created. Tested all inference providers including together, fireworks, vllm, ollama, meta reference and bedrock
This commit is contained in:
parent
e51107e019
commit
fdff24e77a
21 changed files with 460 additions and 290 deletions
|
@ -4,32 +4,61 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from collections import namedtuple
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
from llama_models.sku_list import all_registered_models
|
||||
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
|
||||
ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])
|
||||
|
||||
|
||||
def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
|
||||
for model in all_registered_models():
|
||||
if model.descriptor() == model_descriptor:
|
||||
return model.huggingface_repo
|
||||
return None
|
||||
|
||||
|
||||
def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias:
|
||||
return ModelAlias(
|
||||
provider_model_id=provider_model_id,
|
||||
aliases=[
|
||||
model_descriptor,
|
||||
get_huggingface_repo(model_descriptor),
|
||||
],
|
||||
llama_model=model_descriptor,
|
||||
)
|
||||
|
||||
|
||||
class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||
def __init__(self, model_aliases: List[ModelAlias]):
|
||||
self.alias_to_provider_id_map = {}
|
||||
self.provider_id_to_llama_model_map = {}
|
||||
for alias_obj in model_aliases:
|
||||
for alias in alias_obj.aliases:
|
||||
self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
|
||||
# also add a mapping from provider model id to itself for easy lookup
|
||||
self.alias_to_provider_id_map[alias_obj.provider_model_id] = (
|
||||
alias_obj.provider_model_id
|
||||
)
|
||||
self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = (
|
||||
alias_obj.llama_model
|
||||
)
|
||||
|
||||
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
|
||||
self.stack_to_provider_models_map = stack_to_provider_models_map
|
||||
|
||||
def map_to_provider_model(self, identifier: str) -> str:
|
||||
model = resolve_model(identifier)
|
||||
if not model:
|
||||
def get_provider_model_id(self, identifier: str) -> str:
|
||||
if identifier in self.alias_to_provider_id_map:
|
||||
return self.alias_to_provider_id_map[identifier]
|
||||
else:
|
||||
raise ValueError(f"Unknown model: `{identifier}`")
|
||||
|
||||
if identifier not in self.stack_to_provider_models_map:
|
||||
raise ValueError(
|
||||
f"Model {identifier} not found in map {self.stack_to_provider_models_map}"
|
||||
)
|
||||
def get_llama_model(self, provider_model_id: str) -> str:
|
||||
return self.provider_id_to_llama_model_map[provider_model_id]
|
||||
|
||||
return self.stack_to_provider_models_map[identifier]
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model.provider_resource_id = self.get_provider_model_id(
|
||||
model.provider_resource_id
|
||||
)
|
||||
|
||||
async def register_model(self, model: Model) -> None:
|
||||
if model.identifier not in self.stack_to_provider_models_map:
|
||||
raise ValueError(
|
||||
f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
|
||||
)
|
||||
return model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue