Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-08-03 17:29:01 +00:00
Undo None check and temporarily move if model check before builder
Parent: 16ffe19a20
Commit: 490c5fb730
4 changed files with 7 additions and 7 deletions
@@ -37,8 +37,8 @@ class MetaReferenceInferenceConfig(BaseModel):
     @classmethod
     def validate_model(cls, model: str) -> str:
         permitted_models = supported_inference_models()
-        descriptors = [m.descriptor() for m in permitted_models if m is not None]
-        repos = [m.huggingface_repo for m in permitted_models if m is not None]
+        descriptors = [m.descriptor() for m in permitted_models]
+        repos = [m.huggingface_repo for m in permitted_models]
         if model not in (descriptors + repos):
             model_list = "\n\t".join(repos)
             raise ValueError(
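Note: the two dropped filters are the None checks this commit undoes; supported_inference_models() is trusted to return no None entries. For context, a runnable sketch of the validator's post-commit shape, with stand-ins for the llama-stack helpers; the validator decorator and the ValueError message are assumptions, since the diff shows only @classmethod and truncates inside raise ValueError(.

from pydantic import BaseModel, field_validator


class _Model:
    # Stand-in for llama-stack's model entry type, just so the sketch runs.
    def __init__(self, descriptor: str, huggingface_repo: str) -> None:
        self._descriptor = descriptor
        self.huggingface_repo = huggingface_repo

    def descriptor(self) -> str:
        return self._descriptor


def supported_inference_models() -> list[_Model]:
    # Stand-in for the real registry helper.
    return [_Model("Llama3.1-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct")]


class InferenceConfig(BaseModel):
    model: str

    @field_validator("model")  # decorator assumed; not visible in the diff
    @classmethod
    def validate_model(cls, model: str) -> str:
        permitted_models = supported_inference_models()
        # Post-commit shape: no `if m is not None` filters.
        descriptors = [m.descriptor() for m in permitted_models]
        repos = [m.huggingface_repo for m in permitted_models]
        if model not in (descriptors + repos):
            model_list = "\n\t".join(repos)
            # Message text is assumed; the diff cuts off here.
            raise ValueError(f"Unknown model: `{model}`. Choose from:\n\t{model_list}")
        return model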
@@ -34,6 +34,8 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
+        if model is None:
+            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
         ModelRegistryHelper.__init__(
             self,
             [
@@ -43,8 +45,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
                 )
             ],
         )
-        if model is None:
-            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
         self.model = model
         # verify that the checkpoint actually is for this model lol

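These two hunks are the "temporarily move if model check before builder" part of the commit: resolve_model returns None for unknown names, and the registry "builder" (ModelRegistryHelper.__init__) dereferences the resolved model, so running the check afterwards would surface an AttributeError instead of the intended, actionable RuntimeError. An illustrative sketch with simplified stand-in classes, not the actual llama-stack types:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Model:
    name: str

    def descriptor(self) -> str:
        return self.name


_KNOWN = {"Llama3.1-8B-Instruct": Model("Llama3.1-8B-Instruct")}


def resolve_model(name: str) -> Optional[Model]:
    # Same contract as llama-stack's resolve_model: None for unknown names.
    return _KNOWN.get(name)


class RegistryHelper:
    # Stand-in for ModelRegistryHelper; note it needs a real model to build from.
    def __init__(self, aliases: list[str]) -> None:
        self.aliases = aliases


def init_impl(model_name: str) -> RegistryHelper:
    model = resolve_model(model_name)
    # The check must come first: if it ran after the builder (as before this
    # commit), model.descriptor() below would raise AttributeError on None.
    if model is None:
        raise RuntimeError(f"Unknown model: {model_name}, Run `llama model list`")
    return RegistryHelper([model.descriptor()])


init_impl("not-a-model")  # -> RuntimeError, not AttributeError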
@@ -49,8 +49,8 @@ class VLLMConfig(BaseModel):
     def validate_model(cls, model: str) -> str:
         permitted_models = supported_inference_models()

-        descriptors = [m.descriptor() for m in permitted_models if m is not None]
-        repos = [m.huggingface_repo for m in permitted_models if m is not None]
+        descriptors = [m.descriptor() for m in permitted_models]
+        repos = [m.huggingface_repo for m in permitted_models]
         if model not in (descriptors + repos):
             model_list = "\n\t".join(repos)
             raise ValueError(
@@ -179,7 +179,7 @@ def chat_completion_request_to_messages(
         return request.messages

     allowed_models = supported_inference_models()
-    descriptors = [m.descriptor() for m in allowed_models if m is not None]
+    descriptors = [m.descriptor() for m in allowed_models]
     if model.descriptor() not in descriptors:
         cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
         return request.messages
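The last hunk applies the same None-filter removal inside chat_completion_request_to_messages, which adapts a request's messages to the resolved model's prompt format. Worth noting the guard's failure mode: an unsupported model only warns and passes the messages through. A minimal sketch of that fallback pattern; the function name and allowed_descriptors parameter are placeholders:

from termcolor import cprint


def messages_for_model(descriptor: str, messages: list, allowed_descriptors: list[str]) -> list:
    # Same shape as the guard in the hunk above: an unsupported model
    # warns in red and returns the messages unchanged rather than failing.
    if descriptor not in allowed_descriptors:
        cprint(f"Unsupported inference model? {descriptor}", color="red")
        return messages
    # ... the real function rewrites `messages` into the model family's
    # prompt format here before returning.
    return messages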