diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
index 080e33be0..8b155da30 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
@@ -87,6 +87,9 @@ class Llama:
         This method initializes the distributed process group, sets the
         device to CUDA, and loads the pre-trained model and tokenizer.
         """
+        model = await self.model_store.get_model(config.model)
+        base_model = model.metadata.get("base_model") or config.model
+        self.model = resolve_model(base_model)
         model = resolve_model(config.model)
         llama_model = model.core_model_id.value
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index e7abde227..02e611997 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -24,6 +24,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_media_to_url,
     request_has_media,
 )
+
 from .config import MetaReferenceInferenceConfig
 from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator
@@ -41,9 +42,18 @@ class MetaReferenceInferenceImpl(
 ):
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
         self.config = config
-        model = resolve_model(config.model)
-        if model is None:
-            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
+        self.model_id = config.model
+
+    async def initialize(self) -> None:
+        model = await self.model_store.get_model(self.model_id)
+        base_model = model.metadata.get("base_model") or self.model_id
+        self.model = resolve_model(base_model)
+
+        if self.model is None:
+            raise RuntimeError(
+                f"Unknown model: {self.model_id}, please check if the model or base_model is a native Llama model"
+            )
+
         self.model_registry_helper = ModelRegistryHelper(
             [
                 build_model_alias(
@@ -52,11 +62,9 @@ class MetaReferenceInferenceImpl(
                 )
             ],
         )
-        self.model = model
-        # verify that the checkpoint actually is for this model lol
-    async def initialize(self) -> None:
         log.info(f"Loading model `{self.model.descriptor()}`")
+
         if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(self.config)
             self.generator.start()
@@ -67,11 +75,13 @@ class MetaReferenceInferenceImpl(
         if self.config.create_distributed_process_group:
             self.generator.stop()

-    def check_model(self, request) -> None:
-        model = resolve_model(request.model)
+    async def check_model(self, request) -> None:
+        request_model = await self.model_store.get_model(request.model)
+        base_model = request_model.metadata.get("base_model") or request.model
+        model = resolve_model(base_model)
         if model is None:
             raise RuntimeError(
-                f"Unknown model: {request.model}, Run `llama model list`"
+                f"Unknown model: {request.model}, please check if the model or base_model is a native Llama model"
             )
         elif model.descriptor() != self.model.descriptor():
             raise RuntimeError(
@@ -107,7 +117,7 @@ class MetaReferenceInferenceImpl(
             stream=stream,
             logprobs=logprobs,
         )
-        self.check_model(request)
+        await self.check_model(request)
         request = await request_with_localized_media(request)

         if request.stream:
@@ -232,7 +242,7 @@ class MetaReferenceInferenceImpl(
             stream=stream,
             logprobs=logprobs,
         )
-        self.check_model(request)
+        await self.check_model(request)
         request = await request_with_localized_media(request)

         if self.config.create_distributed_process_group:
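
Reviewer note: below is a minimal, self-contained sketch of the base-model fallback that the new async check_model aims for, assuming model.metadata is a plain dict and that resolve_model returns None for non-native models. The Model, ModelStore, and registry-set definitions are illustrative stand-ins, not the actual llama_stack API.

    import asyncio
    from dataclasses import dataclass, field

    @dataclass
    class Model:
        identifier: str
        metadata: dict = field(default_factory=dict)

    # Stand-in for the native Llama model registry.
    NATIVE_LLAMA_MODELS = {"Llama3.1-8B-Instruct"}

    def resolve_model(descriptor: str):
        # Mimics resolve_model's contract: None when the descriptor is unknown.
        return descriptor if descriptor in NATIVE_LLAMA_MODELS else None

    class ModelStore:
        def __init__(self, models):
            self._models = {m.identifier: m for m in models}

        async def get_model(self, identifier: str) -> Model:
            return self._models[identifier]

    async def check_model(store: ModelStore, model_id: str) -> None:
        request_model = await store.get_model(model_id)
        # Fine-tuned models carry a "base_model" entry in metadata;
        # native models fall back to their own id.
        base_model = request_model.metadata.get("base_model") or model_id
        if resolve_model(base_model) is None:
            raise RuntimeError(
                f"Unknown model: {model_id}, please check if the model "
                f"or base_model is a native Llama model"
            )

    async def main():
        store = ModelStore([
            Model("my-finetune", metadata={"base_model": "Llama3.1-8B-Instruct"}),
            Model("not-llama", metadata={}),
        ])
        await check_model(store, "my-finetune")    # passes via base_model
        try:
            await check_model(store, "not-llama")  # raises RuntimeError
        except RuntimeError as e:
            print(e)

    asyncio.run(main())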