From de44af15017519b4e99852ef44b344d19d1759e7 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Thu, 12 Dec 2024 21:44:03 -0800
Subject: [PATCH] temp commit

---
 llama_stack/distribution/routers/routers.py   |   1 +
 .../distribution/routers/routing_tables.py    |  12 ++
 .../inference/meta_reference/__init__.py      |   6 +-
 .../inline/inference/meta_reference/config.py |   7 +-
 .../inference/meta_reference/generation.py    |  12 +-
 .../inference/meta_reference/inference.py     | 118 +++++++++++++-----
 .../meta_reference/model_parallel.py          |  35 ++++--
 .../providers/utils/inference/__init__.py     |   3 +-
 .../templates/meta-reference-gpu/run.yaml     |  12 +-
 9 files changed, 153 insertions(+), 53 deletions(-)

diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 51be318cb..eeeaf5c52 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -90,6 +90,7 @@ class InferenceRouter(Inference):
         metadata: Optional[Dict[str, Any]] = None,
         model_type: Optional[ModelType] = None,
     ) -> None:
+        print("inference router")
         await self.routing_table.register_model(
             model_id, provider_model_id, provider_id, metadata, model_type
         )
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index bc3de8be0..2b2ed9b4d 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -32,6 +32,7 @@ def get_impl_api(p: Any) -> Api:

 async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
     api = get_impl_api(p)
+    print("registering object with provider", api)

     assert obj.provider_id != "remote", "Remote provider should not be registered"

@@ -169,6 +170,7 @@ class CommonRoutingTableImpl(RoutingTable):
     async def register_object(
         self, obj: RoutableObjectWithProvider
     ) -> RoutableObjectWithProvider:
+
         # Get existing objects from registry
         existing_obj = await self.dist_registry.get(obj.type, obj.identifier)

@@ -181,7 +183,12 @@ class CommonRoutingTableImpl(RoutingTable):

         p = self.impls_by_provider_id[obj.provider_id]

+        if obj is None:
+            print("obj is None")
+
         registered_obj = await register_object_with_provider(obj, p)
+        if registered_obj is None:
+            print("registered_obj is None")
         # TODO: This needs to be fixed for all APIs once they return the registered object
         if obj.type == ResourceType.model.value:
             await self.dist_registry.register(registered_obj)
@@ -211,6 +218,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         metadata: Optional[Dict[str, Any]] = None,
         model_type: Optional[ModelType] = None,
     ) -> Model:
+        print("register_model", model_id)
         if provider_model_id is None:
             provider_model_id = model_id
         if provider_id is None:
@@ -239,7 +247,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             metadata=metadata,
             model_type=model_type,
         )
+        if model is None:
+            print("model is None!!!")
+        print("before registered_model")
         registered_model = await self.register_object(model)
+        print("after registered_model")
         return registered_model

     async def unregister_model(self, model_id: str) -> None:
diff --git a/llama_stack/providers/inline/inference/meta_reference/__init__.py b/llama_stack/providers/inline/inference/meta_reference/__init__.py
index 9c923490d..36d27fa56 100644
--- a/llama_stack/providers/inline/inference/meta_reference/__init__.py
+++ b/llama_stack/providers/inline/inference/meta_reference/__init__.py
@@ -15,6 +15,10 @@ async def get_provider_impl(
 ):
     from .inference import MetaReferenceInferenceImpl

+    print("get_provider_impl")
+
     impl = MetaReferenceInferenceImpl(config)
-    await impl.initialize()
+    if config.model:
+        # pre-load the model if the model is in the config
+        await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py
index 04058d55d..ae04dc780 100644
--- a/llama_stack/providers/inline/inference/meta_reference/config.py
+++ b/llama_stack/providers/inline/inference/meta_reference/config.py
@@ -10,16 +10,13 @@ from llama_models.datatypes import *  # noqa: F403
 from llama_models.sku_list import resolve_model

 from llama_stack.apis.inference import *  # noqa: F401, F403
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, field_validator

 from llama_stack.providers.utils.inference import supported_inference_models


 class MetaReferenceInferenceConfig(BaseModel):
-    model: str = Field(
-        default="Llama3.2-3B-Instruct",
-        description="Model descriptor from `llama model list`",
-    )
+    model: Optional[str] = None
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1
diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
index 8b155da30..92fbaa164 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
@@ -79,6 +79,7 @@ class Llama:
         config: Union[
             MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
         ],
+        request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
     ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
@@ -87,10 +88,13 @@ class Llama:
         This method initializes the distributed process group, sets the
         device to CUDA, and loads the pre-trained model and tokenizer.
         """
-        model = await self.model_store.get_model(config.model)
-        base_model = model.metadata["base_model"] or self.model_id
-        self.model = resolve_model(base_model)
-        model = resolve_model(config.model)
+        if config.model:
+            model = resolve_model(config.model)
+        elif request:
+            model = resolve_model(request.model)
+        else:
+            raise RuntimeError("you need to provide a model for inference")
+
         llama_model = model.core_model_id.value

         if not torch.distributed.is_initialized():
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 02e611997..0b77f9d36 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -11,6 +11,8 @@ from typing import AsyncGenerator, List

 from llama_models.sku_list import resolve_model

+from llama_stack.apis.models import Model as LlamaStackModel
+
 from llama_models.llama3.api.datatypes import *  # noqa: F403

 from llama_stack.providers.utils.inference.model_registry import build_model_alias
@@ -41,49 +43,77 @@ class MetaReferenceInferenceImpl(
     ModelsProtocolPrivate,
 ):
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
+        print("MetaReferenceInferenceImpl init")
         self.config = config
-        self.model_id = config.model
+        self.model = None
+        self.model_registry_helper = None
+        if config.model:
+            model = resolve_model(config.model)
+            if model is None:
+                raise RuntimeError(
+                    f"Unknown model: {config.model}, Run `llama model list`"
+                )
+            self.model_registry_helper = ModelRegistryHelper(
+                [
+                    build_model_alias(
+                        model.descriptor(),
+                        model.core_model_id.value,
+                    )
+                ],
+            )
+            self.model = model
+            # verify that the checkpoint actually is for this model lol
+        else:
+            print("inference model isn't pre-loaded")
+
+    async def _setup_model(self, model_id: str) -> Optional[Model]:
+        model = resolve_model(model_id)
+        if model is None:
+            raise RuntimeError(f"Unknown model: {model_id}, Run `llama model list`")
+        # self.model_registry_helper = ModelRegistryHelper(
+        #     [
+        #         build_model_alias(
+        #             model.descriptor(),
+        #             model.core_model_id.value,
+        #         )
+        #     ],
+        # )
+
+        # return await self.register_model(model)
+        return model

     async def initialize(self) -> None:
-        model = await self.model_store.get_model(self.model_id)
-        base_model = model.metadata["base_model"] or self.model_id
-        self.model = resolve_model(base_model)
-        if self.model is None:
-            raise RuntimeError(
-                f"Unknown model: {self.model_id}, Run please check if the model or base_Model is a native llama model"
-            )
-
-        self.model_registry_helper = ModelRegistryHelper(
-            [
-                build_model_alias(
-                    model.descriptor(),
-                    model.core_model_id.value,
-                )
-            ],
-        )
-
+        raise RuntimeError("model hasn't been setup yet")
         log.info(f"Loading model `{self.model.descriptor()}`")
-
         if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(self.config)
             self.generator.start()
         else:
             self.generator = Llama.build(self.config)

+    async def _lazy_initialize(self, request) -> None:
+        if self.model is None:
+            raise RuntimeError("model hasn't been setup yet")
+        print(f"Lazy loading model `{self.model.descriptor()}`")
+        if self.config.create_distributed_process_group:
+            # with LlamaModelParallelGenerator(self.config, request) as resouce:
+            self.generator = LlamaModelParallelGenerator(self.config, request)
+            self.generator.start()
+        else:
+            self.generator = Llama.build(self.config, request)
+
     async def shutdown(self) -> None:
         if self.config.create_distributed_process_group:
             self.generator.stop()

-    async def check_model(self, request) -> None:
-        request_model = await self.model_store.get_model(request.model)
-        base_model = request_model.metadata["base_model"] or request.model
-        model = resolve_model(base_model)
+    def check_model(self, request) -> None:
+        model = resolve_model(request.model)
         if model is None:
             raise RuntimeError(
-                f"Unknown model: {request.model}, Run please check if the model or base_Model is a native llama model"
+                f"Unknown model: {request.model}, Run `llama model list`"
             )
-        elif model.descriptor() != self.model.descriptor():
+        elif self.model and model.descriptor() != self.model.descriptor():
             raise RuntimeError(
                 f"Model mismatch: {request.model} != {self.model.descriptor()}"
             )
@@ -91,8 +121,23 @@ class MetaReferenceInferenceImpl(
     async def unregister_model(self, model_id: str) -> None:
         pass

-    async def register_model(self, model: Model) -> Model:
+    async def register_model(self, model: LlamaStackModel) -> LlamaStackModel:
+        if self.model_registry_helper is None:
+            llama_model = resolve_model(model.identifier)
+            if llama_model is None:
+                raise RuntimeError(
+                    f"Unknown model: {model.identifier}, Run `llama model list`"
+                )
+            self.model_registry_helper = ModelRegistryHelper(
+                [
+                    build_model_alias(
+                        llama_model.descriptor(),
+                        llama_model.core_model_id.value,
+                    )
+                ],
+            )
         model = await self.model_registry_helper.register_model(model)
+        print("model type", type(model))
         if model.model_type == ModelType.embedding_model:
             self._load_sentence_transformer_model(model.provider_resource_id)
         return model
@@ -117,7 +162,7 @@ class MetaReferenceInferenceImpl(
             stream=stream,
             logprobs=logprobs,
         )
-        await self.check_model(request)
+        self.check_model(request)
         request = await request_with_localized_media(request)

         if request.stream:
@@ -126,6 +171,10 @@ class MetaReferenceInferenceImpl(
         return await self._nonstream_completion(request)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        if self.model is None:
+            self.model = await self._setup_model(request.model)
+        await self._lazy_initialize(request)
+
         def impl():
             stop_reason = None

@@ -175,6 +224,10 @@ class MetaReferenceInferenceImpl(
     async def _nonstream_completion(
         self, request: CompletionRequest
     ) -> CompletionResponse:
+        if self.model is None:
+            self.model = await self._setup_model(request.model)
+        await self._lazy_initialize(request)
+
         def impl():
             tokens = []
             logprobs = []
@@ -242,7 +295,7 @@ class MetaReferenceInferenceImpl(
             stream=stream,
             logprobs=logprobs,
         )
-        await self.check_model(request)
+        self.check_model(request)
         request = await request_with_localized_media(request)

         if self.config.create_distributed_process_group:
@@ -257,6 +310,10 @@ class MetaReferenceInferenceImpl(
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
+        if self.model is None:
+            self.model = await self._setup_model(request.model)
+        await self._lazy_initialize(request)
+
         def impl():
             tokens = []
             logprobs = []
@@ -294,6 +351,7 @@ class MetaReferenceInferenceImpl(

         if self.config.create_distributed_process_group:
             async with SEMAPHORE:
+                print("after SEMAPHORE")
                 return impl()
         else:
             return impl()
@@ -301,6 +359,10 @@ class MetaReferenceInferenceImpl(
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
+        if self.model is None:
+            self.model = await self._setup_model(request.model)
+        await self._lazy_initialize(request)
+
         def impl():
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
diff --git a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
index 7e7831185..ac676a202 100644
--- a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
+++ b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
@@ -7,7 +7,7 @@
 import os
 from copy import deepcopy
 from functools import partial
-from typing import Any, Generator
+from typing import Any, Generator, Optional, Union

 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -34,8 +34,11 @@ class ModelRunner:
         raise ValueError(f"Unexpected task type {type(req)}")


-def init_model_cb(config: MetaReferenceInferenceConfig):
-    llama = Llama.build(config)
+def init_model_cb(
+    config: MetaReferenceInferenceConfig,
+    request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
+):
+    llama = Llama.build(config, request)
     return ModelRunner(llama)


@@ -50,9 +53,21 @@ class LlamaModelParallelGenerator:
     clear at the callsite why we need to use a context manager.
     """

-    def __init__(self, config: MetaReferenceInferenceConfig):
+    def __init__(
+        self,
+        config: MetaReferenceInferenceConfig,
+        request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
+    ):
+        print("LlamaModelParallelGenerator init")
         self.config = config
-        self.model = resolve_model(self.config.model)
+        self.request = request
+        if config.model:
+            self.model = resolve_model(config.model)
+        elif request:
+            self.model = resolve_model(request.model)
+        else:
+            raise RuntimeError("you need to provide a model for inference")
+
         # this is a hack because Agent's loop uses this to tokenize and check if input is too long
         # while the tool-use loop is going
         checkpoint_dir = model_checkpoint_dir(self.model)
@@ -66,9 +81,15 @@ class LlamaModelParallelGenerator:
         self.__exit__(None, None, None)

     def __enter__(self):
+        print("enter LlamaModelParallelGenerator")
+        if self.config.model_parallel_size:
+            model_parallel_size = self.config.model_parallel_size
+        else:
+            model_parallel_size = resolve_model(self.model).pth_file_count
+
         self.group = ModelParallelProcessGroup(
-            self.config.model_parallel_size,
-            init_model_cb=partial(init_model_cb, self.config),
+            model_parallel_size,
+            init_model_cb=partial(init_model_cb, self.config, self.request),
         )
         self.group.start()
         return self
diff --git a/llama_stack/providers/utils/inference/__init__.py b/llama_stack/providers/utils/inference/__init__.py
index 553d02418..d204f98a4 100644
--- a/llama_stack/providers/utils/inference/__init__.py
+++ b/llama_stack/providers/utils/inference/__init__.py
@@ -27,8 +27,7 @@ def supported_inference_models() -> List[Model]:
         m
         for m in all_registered_models()
         if (
-            m.model_family
-            in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}
+            m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
             or is_supported_safety_model(m)
         )
     ]
diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml
index e8fdb10c2..af9b4a89c 100644
--- a/llama_stack/templates/meta-reference-gpu/run.yaml
+++ b/llama_stack/templates/meta-reference-gpu/run.yaml
@@ -16,7 +16,7 @@ providers:
   - provider_id: meta-reference-inference
     provider_type: inline::meta-reference
     config:
-      model: ${env.INFERENCE_MODEL}
+      # model: ${env.INFERENCE_MODEL}
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
   memory:
@@ -72,11 +72,11 @@ metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
-models:
-- metadata: {}
-  model_id: ${env.INFERENCE_MODEL}
-  provider_id: meta-reference-inference
-  provider_model_id: null
+models: []
+# - metadata: {}
+#   model_id: ${env.INFERENCE_MODEL}
+#   provider_id: meta-reference-inference
+#   provider_model_id: null
 shields: []
 memory_banks: []
 datasets: []