forked from phoenix-oss/llama-stack-mirror
Allow models to be registered as long as llama model is provided (#472)
This PR allows models to be registered with provider as long as the user specifies a llama model, even though the model does not match our prebuilt provider specific mapping. Test: pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py -m "together" --env TOGETHER_API_KEY=<KEY> --------- Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
2a31163178
commit
57a9b4d57f
3 changed files with 72 additions and 21 deletions
|
@ -6,7 +6,6 @@
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_models.datatypes import CoreModelId
|
|
||||||
|
|
||||||
# How to run this test:
|
# How to run this test:
|
||||||
#
|
#
|
||||||
|
@ -17,11 +16,22 @@ from llama_models.datatypes import CoreModelId
|
||||||
|
|
||||||
class TestModelRegistration:
|
class TestModelRegistration:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_register_unsupported_model(self, inference_stack):
|
async def test_register_unsupported_model(self, inference_stack, inference_model):
|
||||||
_, models_impl = inference_stack
|
inference_impl, models_impl = inference_stack
|
||||||
|
|
||||||
|
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
||||||
|
if provider.__provider_spec__.provider_type not in (
|
||||||
|
"meta-reference",
|
||||||
|
"remote::ollama",
|
||||||
|
"remote::vllm",
|
||||||
|
"remote::tgi",
|
||||||
|
):
|
||||||
|
pytest.skip(
|
||||||
|
"Skipping test for remote inference providers since they can handle large models like 70B instruct"
|
||||||
|
)
|
||||||
|
|
||||||
# Try to register a model that's too large for local inference
|
# Try to register a model that's too large for local inference
|
||||||
with pytest.raises(Exception) as exc_info:
|
with pytest.raises(ValueError) as exc_info:
|
||||||
await models_impl.register_model(
|
await models_impl.register_model(
|
||||||
model_id="Llama3.1-70B-Instruct",
|
model_id="Llama3.1-70B-Instruct",
|
||||||
)
|
)
|
||||||
|
@ -37,21 +47,27 @@ class TestModelRegistration:
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_model(self, inference_stack):
|
async def test_register_with_llama_model(self, inference_stack):
|
||||||
_, models_impl = inference_stack
|
_, models_impl = inference_stack
|
||||||
|
|
||||||
# Register a model to update
|
_ = await models_impl.register_model(
|
||||||
model_id = CoreModelId.llama3_1_8b_instruct.value
|
model_id="custom-model",
|
||||||
old_model = await models_impl.register_model(model_id=model_id)
|
metadata={"llama_model": "meta-llama/Llama-2-7b"},
|
||||||
|
|
||||||
# Update the model
|
|
||||||
new_model_id = CoreModelId.llama3_2_3b_instruct.value
|
|
||||||
updated_model = await models_impl.update_model(
|
|
||||||
model_id=model_id, provider_model_id=new_model_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Retrieve the updated model to verify changes
|
with pytest.raises(ValueError) as exc_info:
|
||||||
assert updated_model.provider_resource_id != old_model.provider_resource_id
|
await models_impl.register_model(
|
||||||
|
model_id="custom-model-2",
|
||||||
|
metadata={"llama_model": "meta-llama/Llama-2-7b"},
|
||||||
|
provider_model_id="custom-model",
|
||||||
|
)
|
||||||
|
|
||||||
# Cleanup
|
@pytest.mark.asyncio
|
||||||
await models_impl.unregister_model(model_id=model_id)
|
async def test_register_with_invalid_llama_model(self, inference_stack):
|
||||||
|
_, models_impl = inference_stack
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
await models_impl.register_model(
|
||||||
|
model_id="custom-model-2",
|
||||||
|
metadata={"llama_model": "invalid-llama-model"},
|
||||||
|
)
|
||||||
|
|
|
@ -31,3 +31,8 @@ def supported_inference_models() -> List[str]:
|
||||||
or is_supported_safety_model(m)
|
or is_supported_safety_model(m)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR = {
|
||||||
|
m.huggingface_repo: m.descriptor() for m in all_registered_models()
|
||||||
|
}
|
||||||
|
|
|
@ -11,6 +11,10 @@ from llama_models.sku_list import all_registered_models
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.inference import (
|
||||||
|
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
||||||
|
)
|
||||||
|
|
||||||
ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])
|
ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,7 +55,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
if identifier in self.alias_to_provider_id_map:
|
if identifier in self.alias_to_provider_id_map:
|
||||||
return self.alias_to_provider_id_map[identifier]
|
return self.alias_to_provider_id_map[identifier]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown model: `{identifier}`")
|
return None
|
||||||
|
|
||||||
def get_llama_model(self, provider_model_id: str) -> str:
|
def get_llama_model(self, provider_model_id: str) -> str:
|
||||||
if provider_model_id in self.provider_id_to_llama_model_map:
|
if provider_model_id in self.provider_id_to_llama_model_map:
|
||||||
|
@ -60,8 +64,34 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> Model:
|
async def register_model(self, model: Model) -> Model:
|
||||||
model.provider_resource_id = self.get_provider_model_id(
|
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
|
||||||
model.provider_resource_id
|
if provider_resource_id:
|
||||||
|
model.provider_resource_id = provider_resource_id
|
||||||
|
else:
|
||||||
|
if model.metadata.get("llama_model") is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
|
||||||
|
"Please specify a llama_model in metadata or use a supported model identifier"
|
||||||
|
)
|
||||||
|
existing_llama_model = self.get_llama_model(model.provider_resource_id)
|
||||||
|
if existing_llama_model:
|
||||||
|
if existing_llama_model != model.metadata["llama_model"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if (
|
||||||
|
model.metadata["llama_model"]
|
||||||
|
not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
|
||||||
|
f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
|
||||||
|
)
|
||||||
|
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
|
||||||
|
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[
|
||||||
|
model.metadata["llama_model"]
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue