Undo None check and temporarily move if model check before builder

This commit is contained in:
Connor Hack 2024-11-20 19:17:44 -08:00
parent 16ffe19a20
commit 490c5fb730
4 changed files with 7 additions and 7 deletions

View file

@ -37,8 +37,8 @@ class MetaReferenceInferenceConfig(BaseModel):
@classmethod @classmethod
def validate_model(cls, model: str) -> str: def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models() permitted_models = supported_inference_models()
descriptors = [m.descriptor() for m in permitted_models if m is not None] descriptors = [m.descriptor() for m in permitted_models]
repos = [m.huggingface_repo for m in permitted_models if m is not None] repos = [m.huggingface_repo for m in permitted_models]
if model not in (descriptors + repos): if model not in (descriptors + repos):
model_list = "\n\t".join(repos) model_list = "\n\t".join(repos)
raise ValueError( raise ValueError(

View file

@ -34,6 +34,8 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
def __init__(self, config: MetaReferenceInferenceConfig) -> None: def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config self.config = config
model = resolve_model(config.model) model = resolve_model(config.model)
if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
ModelRegistryHelper.__init__( ModelRegistryHelper.__init__(
self, self,
[ [
@ -43,8 +45,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
) )
], ],
) )
if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
self.model = model self.model = model
# verify that the checkpoint actually is for this model lol # verify that the checkpoint actually is for this model lol

View file

@ -49,8 +49,8 @@ class VLLMConfig(BaseModel):
def validate_model(cls, model: str) -> str: def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models() permitted_models = supported_inference_models()
descriptors = [m.descriptor() for m in permitted_models if m is not None] descriptors = [m.descriptor() for m in permitted_models]
repos = [m.huggingface_repo for m in permitted_models if m is not None] repos = [m.huggingface_repo for m in permitted_models]
if model not in (descriptors + repos): if model not in (descriptors + repos):
model_list = "\n\t".join(repos) model_list = "\n\t".join(repos)
raise ValueError( raise ValueError(

View file

@ -179,7 +179,7 @@ def chat_completion_request_to_messages(
return request.messages return request.messages
allowed_models = supported_inference_models() allowed_models = supported_inference_models()
descriptors = [m.descriptor() for m in allowed_models if m is not None] descriptors = [m.descriptor() for m in allowed_models]
if model.descriptor() not in descriptors: if model.descriptor() not in descriptors:
cprint(f"Unsupported inference model? {model.descriptor()}", color="red") cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
return request.messages return request.messages