mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Since we are pushing for HF repos, we should accept them in inference configs (#497)
# What does this PR do?
As the title says: inference provider configs (meta-reference and vLLM) now accept a model's Hugging Face repo name in addition to its Llama descriptor when validating the `model` field.
## Test Plan
This needs 8752149f58 to also land, so the next package (0.0.54) will make this work properly.
The test is:
```bash
pytest -v -s -m "llama_3b and meta_reference" test_model_registration.py
```
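For context, a minimal sketch of the acceptance rule this PR lands in the meta-reference and vLLM configs: a model string passes validation if it matches either a registered model's descriptor or its Hugging Face repo. `RegisteredModel` is a hypothetical stand-in for the SKU objects returned by `supported_inference_models()`, and the example strings are illustrative, not authoritative.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class RegisteredModel:
    # Hypothetical stand-in for the model SKU objects returned by
    # supported_inference_models(); only the two attributes used here are modeled.
    descriptor_str: str
    huggingface_repo: Optional[str] = None

    def descriptor(self) -> str:
        return self.descriptor_str


def validate_model(model: str, permitted_models: List[RegisteredModel]) -> str:
    # The change in this PR: accept either the Llama descriptor or the HF repo name.
    descriptors = [m.descriptor() for m in permitted_models]
    # Ignore models without an HF repo in this sketch.
    repos = [m.huggingface_repo for m in permitted_models if m.huggingface_repo]
    if model not in (descriptors + repos):
        model_list = "\n\t".join(repos)
        raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
    return model


permitted = [RegisteredModel("Llama3.1-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct")]
validate_model("Llama3.1-8B-Instruct", permitted)              # accepted via descriptor
validate_model("meta-llama/Llama-3.1-8B-Instruct", permitted)  # accepted via HF repo
```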
This commit is contained in:
parent b3f9e8b2f2
commit e84d4436b5
5 changed files with 14 additions and 8 deletions
```diff
@@ -37,8 +37,10 @@ class MetaReferenceInferenceConfig(BaseModel):
     @classmethod
     def validate_model(cls, model: str) -> str:
         permitted_models = supported_inference_models()
-        if model not in permitted_models:
-            model_list = "\n\t".join(permitted_models)
+        descriptors = [m.descriptor() for m in permitted_models]
+        repos = [m.huggingface_repo for m in permitted_models]
+        if model not in (descriptors + repos):
+            model_list = "\n\t".join(repos)
             raise ValueError(
                 f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
             )
```
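A side effect visible in this hunk: when validation fails, the suggestion list in the error message is now built from Hugging Face repos rather than descriptors (`model_list` is joined from `repos`). Continuing the hypothetical `validate_model` sketch above:

```python
try:
    validate_model("not-a-model", permitted)
except ValueError as e:
    print(e)
    # Unknown model: `not-a-model`. Choose from [
    #     meta-llama/Llama-3.1-8B-Instruct
    # ]
```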
```diff
@@ -48,8 +48,11 @@ class VLLMConfig(BaseModel):
     @classmethod
     def validate_model(cls, model: str) -> str:
         permitted_models = supported_inference_models()
-        if model not in permitted_models:
-            model_list = "\n\t".join(permitted_models)
+
+        descriptors = [m.descriptor() for m in permitted_models]
+        repos = [m.huggingface_repo for m in permitted_models]
+        if model not in (descriptors + repos):
+            model_list = "\n\t".join(repos)
             raise ValueError(
                 f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
             )
```
```diff
@@ -11,7 +11,6 @@ import pytest
 #
 # pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
 #   -m "meta_reference"
-#   --env TOGETHER_API_KEY=<your_api_key>


 class TestModelRegistration:
```
```diff
@@ -22,9 +22,9 @@ def is_supported_safety_model(model: Model) -> bool:
     ]


-def supported_inference_models() -> List[str]:
+def supported_inference_models() -> List[Model]:
     return [
-        m.descriptor()
+        m
         for m in all_registered_models()
         if (
             m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
```
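Because `supported_inference_models()` now returns model objects rather than descriptor strings, call sites that did plain string membership checks need to go through `m.descriptor()` (or `m.huggingface_repo`) first, which is what the next hunk does. A small sketch of the adjusted call-site pattern, reusing the hypothetical `RegisteredModel` stand-in from above:

```python
from typing import List


def is_supported_descriptor(name: str, models: List["RegisteredModel"]) -> bool:
    # Old pattern: `name in supported_inference_models()` compared against strings.
    # New pattern: extract descriptors from the model objects before the check.
    return name in [m.descriptor() for m in models]
```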
```diff
@@ -178,7 +178,9 @@ def chat_completion_request_to_messages(
         cprint(f"Could not resolve model {llama_model}", color="red")
         return request.messages

-    if model.descriptor() not in supported_inference_models():
+    allowed_models = supported_inference_models()
+    descriptors = [m.descriptor() for m in allowed_models]
+    if model.descriptor() not in descriptors:
         cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
         return request.messages

```