llama-stack/llama_stack/providers/utils/inference/__init__.py
Dinesh Yeduguru 57a9b4d57f
Allow models to be registered as long as llama model is provided (#472)
This PR allows models to be registered with provider as long as the user
specifies a llama model, even though the model does not match our
prebuilt provider specific mapping.
Test:
pytest -v -s
llama_stack/providers/tests/inference/test_model_registration.py -m
"together" --env TOGETHER_API_KEY=<KEY>

---------

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
2024-11-18 15:05:29 -08:00

38 lines
1 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import all_registered_models
def is_supported_safety_model(model: Model) -> bool:
if model.quantization_format != CheckpointQuantizationFormat.bf16:
return False
model_id = model.core_model_id
return model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_1b,
CoreModelId.llama_guard_3_11b_vision,
]
def supported_inference_models() -> List[str]:
return [
m.descriptor()
for m in all_registered_models()
if (
m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
or is_supported_safety_model(m)
)
]
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR = {
m.huggingface_repo: m.descriptor() for m in all_registered_models()
}