mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
feat: Add tags field for models with dynamic and user-defined population
- Implemented a Python function to extract tags from the model identifier field for dynamic population. - Enabled users to specify tags manually when registering a model. - Tags are now included when retrieving model data. Signed-off-by: Habeb Nawatha <habeb.naw@outlook.com>
This commit is contained in:
parent
815f4af6cf
commit
3faa79d76c
2 changed files with 32 additions and 1 deletions
|
@ -20,6 +20,11 @@ class CommonModelFields(BaseModel):
|
||||||
description="Any additional metadata for this model",
|
description="Any additional metadata for this model",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tags: Dict[str, str] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Tags associated with this model as a dictionary",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ModelType(str, Enum):
|
class ModelType(str, Enum):
|
||||||
|
@ -69,6 +74,7 @@ class Models(Protocol):
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
model_type: Optional[ModelType] = None,
|
model_type: Optional[ModelType] = None,
|
||||||
|
tags: Optional[Dict[str, str]] = None,
|
||||||
) -> Model: ...
|
) -> Model: ...
|
||||||
|
|
||||||
@webmethod(route="/models/unregister", method="POST")
|
@webmethod(route="/models/unregister", method="POST")
|
||||||
|
|
|
@ -66,6 +66,21 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||||
Registry = Dict[str, List[RoutableObjectWithProvider]]
|
Registry = Dict[str, List[RoutableObjectWithProvider]]
|
||||||
|
|
||||||
|
|
||||||
|
def extract_tags_from_identifier(identifier: str) -> Dict[str, str]:
|
||||||
|
tags = {}
|
||||||
|
version_match = re.search(r"(\d+\.\d+)", identifier)
|
||||||
|
model_type_match = re.search(r"(Instruct|Vision|Other|chat)", identifier)
|
||||||
|
size_match = re.search(r"(\d+)(B|M)", identifier)
|
||||||
|
|
||||||
|
if version_match:
|
||||||
|
tags["llama_version"] = version_match.group(1)
|
||||||
|
if model_type_match:
|
||||||
|
tags["model_type"] = model_type_match.group(1)
|
||||||
|
if size_match:
|
||||||
|
tags["model_size"] = size_match.group(1) + size_match.group(2)
|
||||||
|
return tags
|
||||||
|
|
||||||
|
|
||||||
class CommonRoutingTableImpl(RoutingTable):
|
class CommonRoutingTableImpl(RoutingTable):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -198,7 +213,14 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
|
|
||||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
async def list_models(self) -> List[Model]:
|
async def list_models(self) -> List[Model]:
|
||||||
return await self.get_all_with_type("model")
|
models = await self.get_all_with_type("model")
|
||||||
|
for model in models:
|
||||||
|
if not model.tags: # If there are no tags, assign them
|
||||||
|
tags = extract_tags_from_identifier(model.identifier)
|
||||||
|
model.tags = tags
|
||||||
|
await self.dist_registry.register(model)
|
||||||
|
|
||||||
|
return models
|
||||||
|
|
||||||
async def get_model(self, identifier: str) -> Optional[Model]:
|
async def get_model(self, identifier: str) -> Optional[Model]:
|
||||||
return await self.get_object_by_identifier("model", identifier)
|
return await self.get_object_by_identifier("model", identifier)
|
||||||
|
@ -210,6 +232,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
model_type: Optional[ModelType] = None,
|
model_type: Optional[ModelType] = None,
|
||||||
|
tags: Optional[Dict[str, str]] = None,
|
||||||
) -> Model:
|
) -> Model:
|
||||||
if provider_model_id is None:
|
if provider_model_id is None:
|
||||||
provider_model_id = model_id
|
provider_model_id = model_id
|
||||||
|
@ -229,12 +252,14 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Embedding model must have an embedding dimension in its metadata"
|
"Embedding model must have an embedding dimension in its metadata"
|
||||||
)
|
)
|
||||||
|
tags = extract_tags_from_identifier(model_id)
|
||||||
model = Model(
|
model = Model(
|
||||||
identifier=model_id,
|
identifier=model_id,
|
||||||
provider_resource_id=provider_model_id,
|
provider_resource_id=provider_model_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
|
tags=tags,
|
||||||
)
|
)
|
||||||
registered_model = await self.register_object(model)
|
registered_model = await self.register_object(model)
|
||||||
return registered_model
|
return registered_model
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue