Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)
Undo None check and temporarily move if model check before builder

commit 490c5fb730 (parent 16ffe19a20)
4 changed files with 7 additions and 7 deletions
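The title maps onto the hunks below: "Undo None check" reverts the `if m is not None` filters over supported_inference_models() in both config validators and in the prompt-adapter helper, while "move if model check before builder" hoists the resolve_model None guard above the ModelRegistryHelper.__init__ call, presumably so an unknown model fails with the intended RuntimeError rather than an AttributeError raised while the builder dereferences model.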
@@ -37,8 +37,8 @@ class MetaReferenceInferenceConfig(BaseModel):
     @classmethod
     def validate_model(cls, model: str) -> str:
         permitted_models = supported_inference_models()
-        descriptors = [m.descriptor() for m in permitted_models if m is not None]
-        repos = [m.huggingface_repo for m in permitted_models if m is not None]
+        descriptors = [m.descriptor() for m in permitted_models]
+        repos = [m.huggingface_repo for m in permitted_models]
         if model not in (descriptors + repos):
             model_list = "\n\t".join(repos)
             raise ValueError(
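As a standalone illustration, here is a minimal, runnable sketch of the restored logic. The ModelEntry stub and the ValueError text are assumptions (the real message is truncated at `raise ValueError(` in the diff); only the descriptor/repo membership test mirrors the hunk.

from dataclasses import dataclass
from typing import List

@dataclass
class ModelEntry:
    # Hypothetical stand-in for the SKU objects that
    # supported_inference_models() returns; only the two attributes
    # the validator reads are modeled here.
    name: str
    huggingface_repo: str

    def descriptor(self) -> str:
        return self.name

def validate_model(model: str, permitted_models: List[ModelEntry]) -> str:
    # Accept either a descriptor or a Hugging Face repo id, as in the hunk.
    descriptors = [m.descriptor() for m in permitted_models]
    repos = [m.huggingface_repo for m in permitted_models]
    if model not in (descriptors + repos):
        model_list = "\n\t".join(repos)
        # Placeholder text; the real message is cut off in the diff.
        raise ValueError(f"Unknown model: {model}. Permitted:\n\t{model_list}")
    return model

permitted = [ModelEntry("Llama3.1-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct")]
validate_model("meta-llama/Llama-3.1-8B-Instruct", permitted)  # returns the input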
@@ -34,6 +34,8 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
+        if model is None:
+            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
         ModelRegistryHelper.__init__(
             self,
             [
@@ -43,8 +45,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
                 )
             ],
         )
-        if model is None:
-            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
         self.model = model
         # verify that the checkpoint actually is for this model lol
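A toy reproduction of the ordering problem the move sidesteps. Every name below is a hypothetical stand-in except resolve_model's contract of returning None for unknown models, which the guard in the diff itself implies.

from typing import Dict, List, Optional

class Model:
    def __init__(self, name: str) -> None:
        self.name = name

    def descriptor(self) -> str:
        return self.name

_KNOWN: Dict[str, Model] = {"Llama3.1-8B-Instruct": Model("Llama3.1-8B-Instruct")}

def resolve_model(name: str) -> Optional[Model]:
    # Mirrors the contract in the hunk: unknown names come back as None.
    return _KNOWN.get(name)

def init_impl(model_name: str) -> List[str]:
    model = resolve_model(model_name)
    # Guard first: if the builder call below ran while `model` is None,
    # `model.descriptor()` would raise AttributeError and the clearer
    # RuntimeError would never be reached.
    if model is None:
        raise RuntimeError(f"Unknown model: {model_name}, Run `llama model list`")
    return [model.descriptor()]  # stand-in for the ModelRegistryHelper builder

init_impl("Llama3.1-8B-Instruct")  # ok
# init_impl("bogus")  # RuntimeError with the friendly message, not AttributeError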
@@ -49,8 +49,8 @@ class VLLMConfig(BaseModel):
     def validate_model(cls, model: str) -> str:
         permitted_models = supported_inference_models()

-        descriptors = [m.descriptor() for m in permitted_models if m is not None]
-        repos = [m.huggingface_repo for m in permitted_models if m is not None]
+        descriptors = [m.descriptor() for m in permitted_models]
+        repos = [m.huggingface_repo for m in permitted_models]
         if model not in (descriptors + repos):
             model_list = "\n\t".join(repos)
             raise ValueError(
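This is the same revert as in MetaReferenceInferenceConfig, so the sketch after the first hunk applies unchanged. For completeness, here is how such a validator is usually attached to the config model; the diff only shows a bare @classmethod, so the pydantic v2 field_validator wiring below is an assumption (v1 code would use @validator instead).

from pydantic import BaseModel, field_validator

class VLLMConfig(BaseModel):
    model: str

    @field_validator("model")  # assumed wiring; decorator not visible in the diff
    @classmethod
    def validate_model(cls, model: str) -> str:
        # ...membership check as in the hunk above...
        return model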
@@ -179,7 +179,7 @@ def chat_completion_request_to_messages(
         return request.messages

     allowed_models = supported_inference_models()
-    descriptors = [m.descriptor() for m in allowed_models if m is not None]
+    descriptors = [m.descriptor() for m in allowed_models]
     if model.descriptor() not in descriptors:
         cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
         return request.messages
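The same None-filter revert, applied here to a soft-fail guard: unsupported models warn and pass the messages through rather than raising. A self-contained sketch of that fall-through behavior; the Request and Model stubs are hypothetical, and cprint is assumed to come from termcolor as elsewhere in llama-stack.

from dataclasses import dataclass, field
from typing import List
from termcolor import cprint

@dataclass
class Model:
    name: str

    def descriptor(self) -> str:
        return self.name

@dataclass
class Request:
    messages: List[str] = field(default_factory=list)

def to_messages(request: Request, model: Model, allowed: List[Model]) -> List[str]:
    descriptors = [m.descriptor() for m in allowed]
    if model.descriptor() not in descriptors:
        # Warn and return the raw messages instead of raising,
        # mirroring the hunk above.
        cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
        return request.messages
    # ...model-specific conversion would happen here...
    return request.messages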