take hugging face repo

This commit is contained in:
Dinesh Yeduguru 2024-11-18 14:57:56 -08:00
parent acf9af841b
commit 8595b2af85
3 changed files with 25 additions and 16 deletions

View file

@ -6,8 +6,6 @@
import pytest import pytest
from llama_models.datatypes import CoreModelId
# How to run this test: # How to run this test:
# #
@ -28,10 +26,12 @@ class TestModelRegistration:
"remote::vllm", "remote::vllm",
"remote::tgi", "remote::tgi",
): ):
pytest.skip("70B instruct is too big only for local inference providers") pytest.skip(
"Skipping test for remote inference providers since they can handle large models like 70B instruct"
)
# Try to register a model that's too large for local inference # Try to register a model that's too large for local inference
with pytest.raises(Exception) as exc_info: with pytest.raises(ValueError) as exc_info:
await models_impl.register_model( await models_impl.register_model(
model_id="Llama3.1-70B-Instruct", model_id="Llama3.1-70B-Instruct",
) )
@ -52,13 +52,13 @@ class TestModelRegistration:
_ = await models_impl.register_model( _ = await models_impl.register_model(
model_id="custom-model", model_id="custom-model",
metadata={"llama_model": CoreModelId.llama3_1_8b_instruct.value}, metadata={"llama_model": "meta-llama/Llama-2-7b"},
) )
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
await models_impl.register_model( await models_impl.register_model(
model_id="custom-model-2", model_id="custom-model-2",
metadata={"llama_model": CoreModelId.llama3_2_3b_instruct.value}, metadata={"llama_model": "meta-llama/Llama-2-7b"},
provider_model_id="custom-model", provider_model_id="custom-model",
) )
@ -66,7 +66,7 @@ class TestModelRegistration:
async def test_register_with_invalid_llama_model(self, inference_stack): async def test_register_with_invalid_llama_model(self, inference_stack):
_, models_impl = inference_stack _, models_impl = inference_stack
with pytest.raises(Exception) as exc_info: with pytest.raises(ValueError) as exc_info:
await models_impl.register_model( await models_impl.register_model(
model_id="custom-model-2", model_id="custom-model-2",
metadata={"llama_model": "invalid-llama-model"}, metadata={"llama_model": "invalid-llama-model"},

View file

@ -31,3 +31,8 @@ def supported_inference_models() -> List[str]:
or is_supported_safety_model(m) or is_supported_safety_model(m)
) )
] ]
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR = {
m.huggingface_repo: m.descriptor() for m in all_registered_models()
}

View file

@ -7,11 +7,14 @@
from collections import namedtuple from collections import namedtuple
from typing import List, Optional from typing import List, Optional
from llama_models.datatypes import CoreModelId
from llama_models.sku_list import all_registered_models from llama_models.sku_list import all_registered_models
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)
ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"]) ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])
@ -77,17 +80,18 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'" f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
) )
else: else:
# Validate that the llama model is a valid CoreModelId if (
try: model.metadata["llama_model"]
CoreModelId(model.metadata["llama_model"]) not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR
except ValueError as err: ):
raise ValueError( raise ValueError(
f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. " f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
f"Must be one of: {', '.join(m.value for m in CoreModelId)}" f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
) from err )
# Register the mapping from provider model id to llama model for future lookups
self.provider_id_to_llama_model_map[model.provider_resource_id] = ( self.provider_id_to_llama_model_map[model.provider_resource_id] = (
model.metadata["llama_model"] ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[
model.metadata["llama_model"]
]
) )
return model return model