Undo None check and temporarily move if model check before builder

Connor Hack 2024-11-20 19:17:44 -08:00
parent 16ffe19a20
commit 490c5fb730
4 changed files with 7 additions and 7 deletions

@@ -37,8 +37,8 @@ class MetaReferenceInferenceConfig(BaseModel):
     @classmethod
     def validate_model(cls, model: str) -> str:
         permitted_models = supported_inference_models()
-        descriptors = [m.descriptor() for m in permitted_models if m is not None]
-        repos = [m.huggingface_repo for m in permitted_models if m is not None]
+        descriptors = [m.descriptor() for m in permitted_models]
+        repos = [m.huggingface_repo for m in permitted_models]
         if model not in (descriptors + repos):
             model_list = "\n\t".join(repos)
             raise ValueError(
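
For context on the `validate_model` changes here and in the later hunks: dropping the `if m is not None` filters reinstates the assumption that `supported_inference_models()` returns only model objects, never None. A minimal standalone sketch of the resulting validator, assuming that invariant; the `Model` class, sample registry, and error message below are illustrative stand-ins, not the real llama_stack code.

from dataclasses import dataclass


@dataclass
class Model:
    # Illustrative stand-in for the real model metadata type.
    core_id: str
    huggingface_repo: str

    def descriptor(self) -> str:
        return self.core_id


# Hypothetical sample registry; in llama_stack this comes from supported_inference_models().
PERMITTED_MODELS = [
    Model("Llama3.1-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct"),
]


def validate_model(model: str) -> str:
    # No `if m is not None` filter: every entry is assumed to be a real model object.
    descriptors = [m.descriptor() for m in PERMITTED_MODELS]
    repos = [m.huggingface_repo for m in PERMITTED_MODELS]
    if model not in (descriptors + repos):
        model_list = "\n\t".join(repos)
        raise ValueError(f"Unknown model: `{model}`. Choose from:\n\t{model_list}")
    return model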

@@ -34,6 +34,8 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
+        if model is None:
+            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
         ModelRegistryHelper.__init__(
             self,
             [
@@ -43,8 +45,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
                 )
             ],
         )
-        if model is None:
-            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
         self.model = model
         # verify that the checkpoint actually is for this model lol
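
The second change is about ordering: `ModelRegistryHelper.__init__` (the "builder" in the commit title) is presumably fed attributes of `model`, so a None result from `resolve_model` has to be rejected before that call rather than after it; otherwise an unknown model surfaces as an AttributeError instead of the intended RuntimeError. A rough sketch of the resulting control flow, using placeholder names rather than the real llama_stack classes:

from typing import List, Optional


class Model:
    # Placeholder for the resolved model metadata object.
    def descriptor(self) -> str:
        return "Llama3.1-8B-Instruct"


def resolve_model(name: str) -> Optional[Model]:
    # Stand-in resolver: returns None for unknown model names.
    return Model() if name == "Llama3.1-8B-Instruct" else None


def build_registry_entries(model: Model) -> List[str]:
    # Dereferences `model`, so it must never be handed None.
    return [model.descriptor()]


class InferenceImplSketch:
    def __init__(self, model_name: str) -> None:
        model = resolve_model(model_name)
        # Guard first: with the old ordering, an unknown model would raise an
        # AttributeError inside build_registry_entries() before reaching the
        # intended RuntimeError below.
        if model is None:
            raise RuntimeError(f"Unknown model: {model_name}, Run `llama model list`")
        self.registry = build_registry_entries(model)
        self.model = model

With this ordering, InferenceImplSketch("not-a-model") fails with the RuntimeError and its `llama model list` hint rather than an AttributeError from the builder.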

@@ -49,8 +49,8 @@ class VLLMConfig(BaseModel):
     def validate_model(cls, model: str) -> str:
         permitted_models = supported_inference_models()
-        descriptors = [m.descriptor() for m in permitted_models if m is not None]
-        repos = [m.huggingface_repo for m in permitted_models if m is not None]
+        descriptors = [m.descriptor() for m in permitted_models]
+        repos = [m.huggingface_repo for m in permitted_models]
         if model not in (descriptors + repos):
             model_list = "\n\t".join(repos)
             raise ValueError(

@@ -179,7 +179,7 @@ def chat_completion_request_to_messages(
         return request.messages
     allowed_models = supported_inference_models()
-    descriptors = [m.descriptor() for m in allowed_models if m is not None]
+    descriptors = [m.descriptor() for m in allowed_models]
     if model.descriptor() not in descriptors:
         cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
         return request.messages