temp commit

Botao Chen 2024-12-12 21:44:03 -08:00
parent 8efe33646d
commit de44af1501
9 changed files with 153 additions and 53 deletions

View file

@@ -90,6 +90,7 @@ class InferenceRouter(Inference):
         metadata: Optional[Dict[str, Any]] = None,
         model_type: Optional[ModelType] = None,
     ) -> None:
+        print("inference router")
         await self.routing_table.register_model(
             model_id, provider_model_id, provider_id, metadata, model_type
         )

View file

@@ -32,6 +32,7 @@ def get_impl_api(p: Any) -> Api:

 async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
     api = get_impl_api(p)
+    print("registering object with provider", api)

     assert obj.provider_id != "remote", "Remote provider should not be registered"
@@ -169,6 +170,7 @@ class CommonRoutingTableImpl(RoutingTable):
     async def register_object(
         self, obj: RoutableObjectWithProvider
     ) -> RoutableObjectWithProvider:
+
         # Get existing objects from registry
         existing_obj = await self.dist_registry.get(obj.type, obj.identifier)
@@ -181,7 +183,12 @@ class CommonRoutingTableImpl(RoutingTable):
         p = self.impls_by_provider_id[obj.provider_id]

+        if obj is None:
+            print("obj is None")
         registered_obj = await register_object_with_provider(obj, p)
+        if registered_obj is None:
+            print("registered_obj is None")
+
         # TODO: This needs to be fixed for all APIs once they return the registered object
         if obj.type == ResourceType.model.value:
             await self.dist_registry.register(registered_obj)
@@ -211,6 +218,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         metadata: Optional[Dict[str, Any]] = None,
         model_type: Optional[ModelType] = None,
     ) -> Model:
+        print("register_model", model_id)
        if provider_model_id is None:
            provider_model_id = model_id
        if provider_id is None:
@@ -239,7 +247,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             metadata=metadata,
             model_type=model_type,
         )
+        if model is None:
+            print("model is None!!!")
+        print("before registered_model")
         registered_model = await self.register_object(model)
+        print("after registered_model")
         return registered_model

     async def unregister_model(self, model_id: str) -> None:

View file

@@ -15,6 +15,10 @@ async def get_provider_impl(
 ):
     from .inference import MetaReferenceInferenceImpl

+    print("get_provider_impl")
     impl = MetaReferenceInferenceImpl(config)
-    await impl.initialize()
+    if config.model:
+        # pre-load the model if the model is in the config
+        await impl.initialize()
     return impl

View file

@@ -10,16 +10,13 @@ from llama_models.datatypes import *  # noqa: F403
 from llama_models.sku_list import resolve_model

 from llama_stack.apis.inference import *  # noqa: F401, F403
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, field_validator

 from llama_stack.providers.utils.inference import supported_inference_models


 class MetaReferenceInferenceConfig(BaseModel):
-    model: str = Field(
-        default="Llama3.2-3B-Instruct",
-        description="Model descriptor from `llama model list`",
-    )
+    model: Optional[str] = None
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1
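
The field change reads more clearly in isolation: `model` is now optional, so the provider can start without a pre-configured model and defer loading to the first request. A minimal sketch, assuming only pydantic; `InferenceConfigSketch` is a hypothetical stand-in, not the real llama_stack class.

    # Standalone sketch of the optional `model` field introduced above.
    from typing import Optional

    from pydantic import BaseModel


    class InferenceConfigSketch(BaseModel):
        model: Optional[str] = None  # was: model: str = Field(default="Llama3.2-3B-Instruct", ...)
        torch_seed: Optional[int] = None
        max_seq_len: int = 4096
        max_batch_size: int = 1


    eager = InferenceConfigSketch(model="Llama3.2-3B-Instruct")  # pre-load at startup
    lazy = InferenceConfigSketch()                               # load on first request
    assert lazy.model is None and eager.max_seq_len == 4096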

View file

@@ -79,6 +79,7 @@ class Llama:
         config: Union[
             MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
         ],
+        request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
     ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
@@ -87,10 +88,13 @@ class Llama:
         This method initializes the distributed process group, sets the device to CUDA,
         and loads the pre-trained model and tokenizer.
         """
-        model = await self.model_store.get_model(config.model)
-        base_model = model.metadata["base_model"] or self.model_id
-        self.model = resolve_model(base_model)
-        model = resolve_model(config.model)
+        if config.model:
+            model = resolve_model(config.model)
+        elif request:
+            model = resolve_model(request.model)
+        else:
+            raise RuntimeError("you need to provide a model for inference")

         llama_model = model.core_model_id.value

         if not torch.distributed.is_initialized():
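
For reference, the branch added to `build` boils down to this precedence; a hedged sketch with a hypothetical helper name, not a llama_stack function.

    # Resolution order Llama.build now follows: configured model wins, then the
    # model named on the incoming request, otherwise it is an error.
    from typing import Optional


    def resolve_target_model(config_model: Optional[str], request_model: Optional[str]) -> str:
        if config_model:
            return config_model   # eager path: model fixed in the provider config
        if request_model:
            return request_model  # lazy path: model taken from the request
        raise RuntimeError("you need to provide a model for inference")


    assert resolve_target_model("Llama3.2-3B-Instruct", None) == "Llama3.2-3B-Instruct"
    assert resolve_target_model(None, "Llama3.2-3B-Instruct") == "Llama3.2-3B-Instruct"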

View file

@@ -11,6 +11,8 @@ from typing import AsyncGenerator, List

 from llama_models.sku_list import resolve_model

+from llama_stack.apis.models import Model as LlamaStackModel
+
 from llama_models.llama3.api.datatypes import *  # noqa: F403

 from llama_stack.providers.utils.inference.model_registry import build_model_alias
@@ -41,19 +43,16 @@ class MetaReferenceInferenceImpl(
     ModelsProtocolPrivate,
 ):
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
+        print("MetaReferenceInferenceImpl init")
         self.config = config
-        self.model_id = config.model
-
-    async def initialize(self) -> None:
-        model = await self.model_store.get_model(self.model_id)
-        base_model = model.metadata["base_model"] or self.model_id
-        self.model = resolve_model(base_model)
-        if self.model is None:
-            raise RuntimeError(
-                f"Unknown model: {self.model_id}, Run please check if the model or base_Model is a native llama model"
-            )
-        self.model_registry_helper = ModelRegistryHelper(
-            [
-                build_model_alias(
+        self.model = None
+        self.model_registry_helper = None
+        if config.model:
+            model = resolve_model(config.model)
+            if model is None:
+                raise RuntimeError(
+                    f"Unknown model: {config.model}, Run `llama model list`"
+                )
+            self.model_registry_helper = ModelRegistryHelper(
+                [
+                    build_model_alias(
@@ -62,28 +61,59 @@ class MetaReferenceInferenceImpl(
                     )
                 ],
             )
+            self.model = model
+            # verify that the checkpoint actually is for this model lol
+        else:
+            print("inference model isn't pre-loaded")
+
+    async def _setup_model(self, model_id: str) -> Optional[Model]:
+        model = resolve_model(model_id)
+        if model is None:
+            raise RuntimeError(f"Unknown model: {model_id}, Run `llama model list`")
+        # self.model_registry_helper = ModelRegistryHelper(
+        #     [
+        #         build_model_alias(
+        #             model.descriptor(),
+        #             model.core_model_id.value,
+        #         )
+        #     ],
+        # )
+        # return await self.register_model(model)
+        return model
+
+    async def initialize(self) -> None:
+        if self.model is None:
+            raise RuntimeError("model hasn't been setup yet")
         log.info(f"Loading model `{self.model.descriptor()}`")
         if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(self.config)
             self.generator.start()
         else:
             self.generator = Llama.build(self.config)

+    async def _lazy_initialize(self, request) -> None:
+        if self.model is None:
+            raise RuntimeError("model hasn't been setup yet")
+        print(f"Lazy loading model `{self.model.descriptor()}`")
+        if self.config.create_distributed_process_group:
+            # with LlamaModelParallelGenerator(self.config, request) as resouce:
+            self.generator = LlamaModelParallelGenerator(self.config, request)
+            self.generator.start()
+        else:
+            self.generator = Llama.build(self.config, request)
+
     async def shutdown(self) -> None:
         if self.config.create_distributed_process_group:
             self.generator.stop()

-    async def check_model(self, request) -> None:
-        request_model = await self.model_store.get_model(request.model)
-        base_model = request_model.metadata["base_model"] or request.model
-        model = resolve_model(base_model)
+    def check_model(self, request) -> None:
+        model = resolve_model(request.model)
         if model is None:
             raise RuntimeError(
-                f"Unknown model: {request.model}, Run please check if the model or base_Model is a native llama model"
+                f"Unknown model: {request.model}, Run `llama model list`"
             )
-        elif model.descriptor() != self.model.descriptor():
+        elif self.model and model.descriptor() != self.model.descriptor():
             raise RuntimeError(
                 f"Model mismatch: {request.model} != {self.model.descriptor()}"
             )
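
Taken together, `_setup_model`, `_lazy_initialize`, and the relaxed `check_model` give the request path a load-on-first-use shape. A hedged sketch of that flow; `ToyGenerator` and `LazyProvider` are illustrative stand-ins, not llama_stack APIs.

    # Check the request against any already-loaded model, load on first use,
    # then reuse the generator on later requests.
    from typing import Optional


    class ToyGenerator:
        def __init__(self, model_id: str) -> None:
            self.model_id = model_id  # pretend the checkpoint is loaded here


    class LazyProvider:
        def __init__(self, configured_model: Optional[str] = None) -> None:
            self.model: Optional[str] = configured_model
            self.generator: Optional[ToyGenerator] = None

        def check_model(self, requested: str) -> None:
            # mirror the relaxed check: only compare once a model is actually set
            if self.model and requested != self.model:
                raise RuntimeError(f"Model mismatch: {requested} != {self.model}")

        def completion(self, requested: str) -> str:
            self.check_model(requested)
            if self.generator is None:           # _setup_model + _lazy_initialize
                self.model = requested
                self.generator = ToyGenerator(requested)
            return f"generated with {self.generator.model_id}"


    provider = LazyProvider()                    # nothing pre-loaded
    provider.completion("Llama3.2-3B-Instruct")  # loads on the first call
    provider.completion("Llama3.2-3B-Instruct")  # reuses the loaded generator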
@@ -91,8 +121,23 @@ class MetaReferenceInferenceImpl(
     async def unregister_model(self, model_id: str) -> None:
         pass

-    async def register_model(self, model: Model) -> Model:
+    async def register_model(self, model: LlamaStackModel) -> LlamaStackModel:
+        if self.model_registry_helper is None:
+            llama_model = resolve_model(model.identifier)
+            if llama_model is None:
+                raise RuntimeError(
+                    f"Unknown model: {model.identifier}, Run `llama model list`"
+                )
+            self.model_registry_helper = ModelRegistryHelper(
+                [
+                    build_model_alias(
+                        llama_model.descriptor(),
+                        llama_model.core_model_id.value,
+                    )
+                ],
+            )
         model = await self.model_registry_helper.register_model(model)
+        print("model type", type(model))
         if model.model_type == ModelType.embedding_model:
             self._load_sentence_transformer_model(model.provider_resource_id)
         return model
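
The same deferral applies to the alias helper: it is now built on the first registration, keyed off the registered model, instead of in `__init__`. A small sketch under that assumption; the classes below are stand-ins for `ModelRegistryHelper`/`build_model_alias`, and the alias values are illustrative.

    # Build a dependency lazily from the first registered model.
    from typing import Dict, Optional


    class AliasHelperSketch:
        def __init__(self, aliases: Dict[str, str]) -> None:
            self.aliases = aliases  # provider model id -> core model id


    class ProviderSketch:
        def __init__(self) -> None:
            self.helper: Optional[AliasHelperSketch] = None  # deferred: no configured model

        def register_model(self, identifier: str, core_model_id: str) -> None:
            if self.helper is None:
                # construct the helper from the first registered model
                self.helper = AliasHelperSketch({identifier: core_model_id})


    p = ProviderSketch()
    p.register_model("Llama3.2-3B-Instruct", "llama3_2_3b_instruct")
    assert p.helper is not None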
@@ -117,7 +162,7 @@ class MetaReferenceInferenceImpl(
             stream=stream,
             logprobs=logprobs,
         )
-        await self.check_model(request)
+        self.check_model(request)
         request = await request_with_localized_media(request)

         if request.stream:
@@ -126,6 +171,10 @@ class MetaReferenceInferenceImpl(
             return await self._nonstream_completion(request)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        if self.model is None:
+            self.model = await self._setup_model(request.model)
+            await self._lazy_initialize(request)
+
         def impl():
             stop_reason = None
@@ -175,6 +224,10 @@ class MetaReferenceInferenceImpl(
     async def _nonstream_completion(
         self, request: CompletionRequest
     ) -> CompletionResponse:
+        if self.model is None:
+            self.model = await self._setup_model(request.model)
+            await self._lazy_initialize(request)
+
         def impl():
             tokens = []
             logprobs = []
@@ -242,7 +295,7 @@ class MetaReferenceInferenceImpl(
             stream=stream,
             logprobs=logprobs,
         )
-        await self.check_model(request)
+        self.check_model(request)
         request = await request_with_localized_media(request)

         if self.config.create_distributed_process_group:
@@ -257,6 +310,10 @@ class MetaReferenceInferenceImpl(
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
+        if self.model is None:
+            self.model = await self._setup_model(request.model)
+            await self._lazy_initialize(request)
+
         def impl():
             tokens = []
             logprobs = []
@@ -294,6 +351,7 @@ class MetaReferenceInferenceImpl(
         if self.config.create_distributed_process_group:
             async with SEMAPHORE:
+                print("after SEMAPHORE")
                 return impl()
         else:
             return impl()
@@ -301,6 +359,10 @@ class MetaReferenceInferenceImpl(
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
+        if self.model is None:
+            self.model = await self._setup_model(request.model)
+            await self._lazy_initialize(request)
+
         def impl():
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(

View file

@@ -7,7 +7,7 @@
 import os
 from copy import deepcopy
 from functools import partial
-from typing import Any, Generator
+from typing import Any, Generator, Optional, Union

 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -34,8 +34,11 @@ class ModelRunner:
             raise ValueError(f"Unexpected task type {type(req)}")


-def init_model_cb(config: MetaReferenceInferenceConfig):
-    llama = Llama.build(config)
+def init_model_cb(
+    config: MetaReferenceInferenceConfig,
+    request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
+):
+    llama = Llama.build(config, request)
     return ModelRunner(llama)
@@ -50,9 +53,21 @@ class LlamaModelParallelGenerator:
     clear at the callsite why we need to use a context manager.
     """

-    def __init__(self, config: MetaReferenceInferenceConfig):
+    def __init__(
+        self,
+        config: MetaReferenceInferenceConfig,
+        request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
+    ):
+        print("LlamaModelParallelGenerator init")
         self.config = config
-        self.model = resolve_model(self.config.model)
+        self.request = request
+        if config.model:
+            self.model = resolve_model(config.model)
+        elif request:
+            self.model = resolve_model(request.model)
+        else:
+            raise RuntimeError("you need to provide a model for inference")
+
         # this is a hack because Agent's loop uses this to tokenize and check if input is too long
         # while the tool-use loop is going
         checkpoint_dir = model_checkpoint_dir(self.model)
@@ -66,9 +81,15 @@ class LlamaModelParallelGenerator:
         self.__exit__(None, None, None)

     def __enter__(self):
+        print("enter LlamaModelParallelGenerator")
+        if self.config.model_parallel_size:
+            model_parallel_size = self.config.model_parallel_size
+        else:
+            model_parallel_size = resolve_model(self.model).pth_file_count
+
         self.group = ModelParallelProcessGroup(
-            self.config.model_parallel_size,
-            init_model_cb=partial(init_model_cb, self.config),
+            model_parallel_size,
+            init_model_cb=partial(init_model_cb, self.config, self.request),
         )
         self.group.start()
         return self
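
The new fallback in `__enter__` is simple: use the explicit `model_parallel_size` from the config when present, otherwise one rank per checkpoint shard via `pth_file_count`, as in the diff. A hedged sketch; the dataclass stands in for the resolved llama_models model object.

    # Model-parallel size: explicit config value wins, else shard count.
    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class ResolvedModelSketch:
        pth_file_count: int  # number of consolidated.*.pth shards in the checkpoint


    def pick_model_parallel_size(configured: Optional[int], model: ResolvedModelSketch) -> int:
        return configured if configured else model.pth_file_count


    assert pick_model_parallel_size(None, ResolvedModelSketch(pth_file_count=8)) == 8
    assert pick_model_parallel_size(2, ResolvedModelSketch(pth_file_count=8)) == 2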

View file

@@ -27,8 +27,7 @@ def supported_inference_models() -> List[Model]:
         m
         for m in all_registered_models()
         if (
-            m.model_family
-            in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}
+            m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
             or is_supported_safety_model(m)
         )
     ]

View file

@@ -16,7 +16,7 @@ providers:
   - provider_id: meta-reference-inference
     provider_type: inline::meta-reference
     config:
-      model: ${env.INFERENCE_MODEL}
+      # model: ${env.INFERENCE_MODEL}
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
   memory:
@@ -72,11 +72,11 @@ metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
-models:
-- metadata: {}
-  model_id: ${env.INFERENCE_MODEL}
-  provider_id: meta-reference-inference
-  provider_model_id: null
+models: []
+# - metadata: {}
+#   model_id: ${env.INFERENCE_MODEL}
+#   provider_id: meta-reference-inference
+#   provider_model_id: null
 shields: []
 memory_banks: []
 datasets: []
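
With the model commented out of the provider config and `models: []`, a model would be registered at runtime instead of at stack startup. A hedged sketch, assuming the llama-stack client's `models.register` mirrors the `register_model` signature shown earlier in this diff; adjust to the actual SDK.

    # Register the inference model against the running stack instead of run.yaml.
    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:5000")

    client.models.register(
        model_id="Llama3.2-3B-Instruct",
        provider_id="meta-reference-inference",  # matches the provider_id in run.yaml
    )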