Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-31 16:01:46 +00:00
take hugging face repo
commit 8595b2af85 (parent acf9af841b)
3 changed files with 25 additions and 16 deletions
@@ -6,8 +6,6 @@
 
 import pytest
 
-from llama_models.datatypes import CoreModelId
-
 
 # How to run this test:
 #
@@ -28,10 +26,12 @@ class TestModelRegistration:
             "remote::vllm",
             "remote::tgi",
         ):
-            pytest.skip("70B instruct is too big only for local inference providers")
+            pytest.skip(
+                "Skipping test for remote inference providers since they can handle large models like 70B instruct"
+            )
 
         # Try to register a model that's too large for local inference
-        with pytest.raises(Exception) as exc_info:
+        with pytest.raises(ValueError) as exc_info:
             await models_impl.register_model(
                 model_id="Llama3.1-70B-Instruct",
             )
@@ -52,13 +52,13 @@ class TestModelRegistration:
 
         _ = await models_impl.register_model(
             model_id="custom-model",
-            metadata={"llama_model": CoreModelId.llama3_1_8b_instruct.value},
+            metadata={"llama_model": "meta-llama/Llama-2-7b"},
         )
 
         with pytest.raises(ValueError) as exc_info:
             await models_impl.register_model(
                 model_id="custom-model-2",
-                metadata={"llama_model": CoreModelId.llama3_2_3b_instruct.value},
+                metadata={"llama_model": "meta-llama/Llama-2-7b"},
                 provider_model_id="custom-model",
             )
 
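After this hunk, custom model registrations identify the underlying llama model by its Hugging Face repo (here "meta-llama/Llama-2-7b") rather than by a CoreModelId value. A minimal usage sketch of the call exercised above, assuming models_impl is the models API implementation from the same inference_stack fixture and the code runs inside an async test:

# Sketch only: models_impl comes from the inference_stack fixture used in these
# tests, and the llama_model value must be a Hugging Face repo known to the
# model registry.
_ = await models_impl.register_model(
    model_id="custom-model",
    metadata={"llama_model": "meta-llama/Llama-2-7b"},
)

A repo the registry does not know about is rejected with a ValueError, as the invalid-model test below checks.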
@@ -66,7 +66,7 @@ class TestModelRegistration:
     async def test_register_with_invalid_llama_model(self, inference_stack):
         _, models_impl = inference_stack
 
-        with pytest.raises(Exception) as exc_info:
+        with pytest.raises(ValueError) as exc_info:
             await models_impl.register_model(
                 model_id="custom-model-2",
                 metadata={"llama_model": "invalid-llama-model"},
@@ -31,3 +31,8 @@ def supported_inference_models() -> List[str]:
             or is_supported_safety_model(m)
         )
     ]
+
+
+ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR = {
+    m.huggingface_repo: m.descriptor() for m in all_registered_models()
+}
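The new module-level mapping keys every registered llama model by its Hugging Face repo and maps it to the model's descriptor string. An illustrative sketch of what it contains, assuming the llama_models package is installed; repo_to_descriptor is a hypothetical local name mirroring the dict added above:

from llama_models.sku_list import all_registered_models

# Build the same repo -> descriptor mapping as the hunk above.
repo_to_descriptor = {
    m.huggingface_repo: m.descriptor() for m in all_registered_models()
}

# Look up the descriptor for a repo used in the tests; the exact descriptor
# string depends on the installed llama_models version.
print(repo_to_descriptor.get("meta-llama/Llama-2-7b"))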
@@ -7,11 +7,14 @@
 from collections import namedtuple
 from typing import List, Optional
 
-from llama_models.datatypes import CoreModelId
 from llama_models.sku_list import all_registered_models
 
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 
+from llama_stack.providers.utils.inference import (
+    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
+)
+
 ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])
 
 
@@ -77,17 +80,18 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
                     f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
                 )
         else:
-            # Validate that the llama model is a valid CoreModelId
-            try:
-                CoreModelId(model.metadata["llama_model"])
-            except ValueError as err:
+            if (
+                model.metadata["llama_model"]
+                not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR
+            ):
                 raise ValueError(
                     f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
-                    f"Must be one of: {', '.join(m.value for m in CoreModelId)}"
-                ) from err
-            # Register the mapping from provider model id to llama model for future lookups
+                    f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
+                )
             self.provider_id_to_llama_model_map[model.provider_resource_id] = (
-                model.metadata["llama_model"]
+                ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[
+                    model.metadata["llama_model"]
+                ]
             )
 
         return model
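With this hunk, the registry helper validates the llama_model metadata against the Hugging Face repo keys and stores the resolved descriptor, instead of accepting any CoreModelId value. A standalone sketch of that flow under the same assumptions; resolve_llama_model is a hypothetical helper name, not part of the diff:

from llama_stack.providers.utils.inference import (
    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)


def resolve_llama_model(llama_model_repo: str) -> str:
    # Reject repos that no registered llama model uses.
    if llama_model_repo not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
        raise ValueError(
            f"Invalid llama_model '{llama_model_repo}' specified in metadata. "
            f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
        )
    # Return the descriptor that register_model stores in
    # provider_id_to_llama_model_map.
    return ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model_repo]

register_model then records this descriptor under the provider model id, so later lookups by provider_resource_id yield a llama model descriptor rather than the raw repo string.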