From 35b1a6f2dc4d7f2f9305395e77e15ebe7b60122b Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Mon, 16 Dec 2024 16:44:15 -0800
Subject: [PATCH] temp commit

Load the meta-reference inference model lazily: instead of pre-loading the
model named in the provider config at startup, initialize the generator when
a model is registered, and pass an explicit model_id through Llama.build and
LlamaModelParallelGenerator instead of a completion request. Models registered
with `skip_initialize` in their metadata are not loaded immediately. The
meta-reference-gpu template now registers ${env.INFERENCE_MODEL} as a model
resource.

---
 llama_stack/distribution/routers/routers.py   |   1 +
 .../inference/meta_reference/__init__.py      |   5 +-
 .../inline/inference/meta_reference/config.py |   4 +-
 .../inference/meta_reference/generation.py    |   9 +-
 .../inference/meta_reference/inference.py     | 119 +++++-------------
 .../meta_reference/model_parallel.py          |  19 ++-
 .../templates/meta-reference-gpu/run.yaml     |  10 +-
 7 files changed, 54 insertions(+), 113 deletions(-)

diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index eeeaf5c52..06c232456 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -74,6 +74,7 @@ class InferenceRouter(Inference):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        print("InferenceRouter init")
         self.routing_table = routing_table
 
     async def initialize(self) -> None:
diff --git a/llama_stack/providers/inline/inference/meta_reference/__init__.py b/llama_stack/providers/inline/inference/meta_reference/__init__.py
index 36d27fa56..b7e935ebf 100644
--- a/llama_stack/providers/inline/inference/meta_reference/__init__.py
+++ b/llama_stack/providers/inline/inference/meta_reference/__init__.py
@@ -18,7 +18,6 @@ async def get_provider_impl(
     print("get_provider_impl")
 
     impl = MetaReferenceInferenceImpl(config)
-    if config.model:
-        # pre-load the model if the model is in the config
-        await impl.initialize()
+
+    print("after MetaReferenceInferenceImpl")
     return impl
diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py
index ae04dc780..ffc27c08c 100644
--- a/llama_stack/providers/inline/inference/meta_reference/config.py
+++ b/llama_stack/providers/inline/inference/meta_reference/config.py
@@ -16,7 +16,9 @@ from llama_stack.providers.utils.inference import supported_inference_models
 
 
 class MetaReferenceInferenceConfig(BaseModel):
-    model: Optional[str] = None
+    model: Optional[str] = (
+        None  # this is a placeholder to indicate inference model id, not actually being used
+    )
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1
diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
index 92fbaa164..eebe7b61d 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
@@ -79,7 +79,7 @@ class Llama:
         config: Union[
             MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
         ],
-        request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
+        model_id: str,
     ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
@@ -87,12 +87,7 @@ class Llama:
         This method initializes the distributed process group, sets the device to CUDA,
         and loads the pre-trained model and tokenizer.
""" - if config.model: - model = resolve_model(config.model) - elif request: - model = resolve_model(request.model) - else: - raise RuntimeError("you need to provide a model for inference") + model = resolve_model(model_id) llama_model = model.core_model_id.value diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 0b77f9d36..53edaf96c 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -43,103 +43,68 @@ class MetaReferenceInferenceImpl( ModelsProtocolPrivate, ): def __init__(self, config: MetaReferenceInferenceConfig) -> None: - print("MetaReferenceInferenceImpl init") self.config = config self.model = None - self.model_registry_helper = None - if config.model: - model = resolve_model(config.model) - if model is None: - raise RuntimeError( - f"Unknown model: {config.model}, Run `llama model list`" - ) - self.model_registry_helper = ModelRegistryHelper( - [ - build_model_alias( - model.descriptor(), - model.core_model_id.value, - ) - ], - ) - self.model = model - # verify that the checkpoint actually is for this model lol - else: - print("inference model isn't pre-loaded") - async def _setup_model(self, model_id: str) -> Optional[Model]: - model = resolve_model(model_id) - if model is None: - raise RuntimeError(f"Unknown model: {model_id}, Run `llama model list`") - # self.model_registry_helper = ModelRegistryHelper( - # [ - # build_model_alias( - # model.descriptor(), - # model.core_model_id.value, - # ) - # ], - # ) - - # return await self.register_model(model) - return model - - async def initialize(self) -> None: - if self.model is None: - raise RuntimeError("model hasn't been setup yet") - log.info(f"Loading model `{self.model.descriptor()}`") + async def initialize(self, model_id) -> None: + log.info(f"Loading model `{model_id}`") if self.config.create_distributed_process_group: - self.generator = LlamaModelParallelGenerator(self.config) + self.generator = LlamaModelParallelGenerator(self.config, model_id) self.generator.start() else: - self.generator = Llama.build(self.config) + self.generator = Llama.build(self.config, model_id) - async def _lazy_initialize(self, request) -> None: - if self.model is None: - raise RuntimeError("model hasn't been setup yet") - print(f"Lazy loading model `{self.model.descriptor()}`") - if self.config.create_distributed_process_group: - # with LlamaModelParallelGenerator(self.config, request) as resouce: - self.generator = LlamaModelParallelGenerator(self.config, request) - self.generator.start() - else: - self.generator = Llama.build(self.config, request) + self.model = model_id async def shutdown(self) -> None: if self.config.create_distributed_process_group: self.generator.stop() def check_model(self, request) -> None: - model = resolve_model(request.model) - if model is None: + if self.model is None: + raise RuntimeError( + "Inference model hasn't been initialized yet, please register your requested model or add your model in the resouces first" + ) + inference_model = resolve_model(self.model) + requested_model = resolve_model(request.model) + if requested_model is None: raise RuntimeError( f"Unknown model: {request.model}, Run `llama model list`" ) - elif self.model and model.descriptor() != self.model.descriptor(): + elif requested_model.descriptor() != inference_model.descriptor(): raise RuntimeError( - f"Model mismatch: {request.model} != 
+                f"Model mismatch: {request.model} != {inference_model.descriptor()}"
             )
 
     async def unregister_model(self, model_id: str) -> None:
         pass
 
     async def register_model(self, model: LlamaStackModel) -> LlamaStackModel:
-        if self.model_registry_helper is None:
-            llama_model = resolve_model(model.identifier)
-            if llama_model is None:
-                raise RuntimeError(
-                    f"Unknown model: {model.identifier}, Run `llama model list`"
-                )
-            self.model_registry_helper = ModelRegistryHelper(
-                [
-                    build_model_alias(
-                        llama_model.descriptor(),
-                        llama_model.core_model_id.value,
-                    )
-                ],
+        llama_model = resolve_model(model.identifier)
+        if llama_model is None:
+            raise RuntimeError(
+                f"Unknown model: {model.identifier}, Please make sure your model is in llama-models SKU list"
             )
+        self.model_registry_helper = ModelRegistryHelper(
+            [
+                build_model_alias(
+                    llama_model.descriptor(),
+                    llama_model.core_model_id.value,
+                )
+            ],
+        )
         model = await self.model_registry_helper.register_model(model)
         print("model type", type(model))
         if model.model_type == ModelType.embedding_model:
             self._load_sentence_transformer_model(model.provider_resource_id)
+
+        if (
+            model.metadata
+            and "skip_initialize" in model.metadata
+            and model.metadata["skip_initialize"]
+        ):
+            return model
+        await self.initialize(model.identifier)
         return model
 
     async def completion(
@@ -171,10 +136,6 @@ class MetaReferenceInferenceImpl(
         return await self._nonstream_completion(request)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        if self.model is None:
-            self.model = await self._setup_model(request.model)
-            await self._lazy_initialize(request)
-
         def impl():
             stop_reason = None
 
@@ -224,10 +185,6 @@ class MetaReferenceInferenceImpl(
     async def _nonstream_completion(
         self, request: CompletionRequest
     ) -> CompletionResponse:
-        if self.model is None:
-            self.model = await self._setup_model(request.model)
-            await self._lazy_initialize(request)
-
         def impl():
             tokens = []
             logprobs = []
@@ -310,10 +267,6 @@ class MetaReferenceInferenceImpl(
    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
-        if self.model is None:
-            self.model = await self._setup_model(request.model)
-            await self._lazy_initialize(request)
-
        def impl():
            tokens = []
            logprobs = []
@@ -359,10 +312,6 @@ class MetaReferenceInferenceImpl(
    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator:
-        if self.model is None:
-            self.model = await self._setup_model(request.model)
-            await self._lazy_initialize(request)
-
        def impl():
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
diff --git a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
index ac676a202..3eb11bf5a 100644
--- a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
+++ b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
@@ -7,7 +7,7 @@
 import os
 from copy import deepcopy
 from functools import partial
-from typing import Any, Generator, Optional, Union
+from typing import Any, Generator
 
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -36,9 +36,9 @@ class ModelRunner:
 
 def init_model_cb(
     config: MetaReferenceInferenceConfig,
-    request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
+    model_id: str,
 ):
-    llama = Llama.build(config, request)
+    llama = Llama.build(config, model_id)
     return ModelRunner(llama)
 
 
@@ -56,17 +56,12 @@ class LlamaModelParallelGenerator:
     def __init__(
         self,
         config: MetaReferenceInferenceConfig,
-        request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
+        model_id: str,
     ):
         print("LlamaModelParallelGenerator init")
         self.config = config
-        self.request = request
-        if config.model:
-            self.model = resolve_model(config.model)
-        elif request:
-            self.model = resolve_model(request.model)
-        else:
-            raise RuntimeError("you need to provide a model for inference")
+        self.model_id = model_id
+        self.model = resolve_model(model_id)
 
         # this is a hack because Agent's loop uses this to tokenize and check if input is too long
         # while the tool-use loop is going
@@ -89,7 +84,7 @@ class LlamaModelParallelGenerator:
 
         self.group = ModelParallelProcessGroup(
             model_parallel_size,
-            init_model_cb=partial(init_model_cb, self.config, self.request),
+            init_model_cb=partial(init_model_cb, self.config, self.model_id),
         )
         self.group.start()
         return self
diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml
index af9b4a89c..175988f7c 100644
--- a/llama_stack/templates/meta-reference-gpu/run.yaml
+++ b/llama_stack/templates/meta-reference-gpu/run.yaml
@@ -16,7 +16,7 @@ providers:
   - provider_id: meta-reference-inference
     provider_type: inline::meta-reference
     config:
-      # model: ${env.INFERENCE_MODEL}
+      model: ${env.INFERENCE_MODEL} # please make sure your inference model here is added as resource
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
   memory:
@@ -73,10 +73,10 @@ metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
 models: []
-# - metadata: {}
-#   model_id: ${env.INFERENCE_MODEL}
-#   provider_id: meta-reference-inference
-#   provider_model_id: null
+- metadata: {}
+  model_id: ${env.INFERENCE_MODEL}
+  provider_id: meta-reference-inference
+  provider_model_id: null
 shields: []
 memory_banks: []
 datasets: []