temp commit

Botao Chen 2024-12-16 16:44:15 -08:00
parent de44af1501
commit 35b1a6f2dc
7 changed files with 54 additions and 113 deletions


@@ -74,6 +74,7 @@ class InferenceRouter(Inference):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        print("InferenceRouter init")
         self.routing_table = routing_table

     async def initialize(self) -> None:


@@ -18,7 +18,6 @@ async def get_provider_impl(
     print("get_provider_impl")
     impl = MetaReferenceInferenceImpl(config)
-    if config.model:
-        # pre-load the model if the model is in the config
-        await impl.initialize()
+    print("after MetaReferenceInferenceImpl")
     return impl


@@ -16,7 +16,9 @@ from llama_stack.providers.utils.inference import supported_inference_models
 class MetaReferenceInferenceConfig(BaseModel):
-    model: Optional[str] = None
+    model: Optional[str] = (
+        None  # this is a placeholder to indicate inference model id, not actually being used
+    )
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1


@@ -79,7 +79,7 @@ class Llama:
         config: Union[
             MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
         ],
-        request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
+        model_id: str,
     ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
@@ -88,12 +88,7 @@ class Llama:
         This method initializes the distributed process group, sets the device to CUDA,
         and loads the pre-trained model and tokenizer.
         """
-        if config.model:
-            model = resolve_model(config.model)
-        elif request:
-            model = resolve_model(request.model)
-        else:
-            raise RuntimeError("you need to provide a model for inference")
+        model = resolve_model(model_id)

         llama_model = model.core_model_id.value
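
For orientation, a rough sketch of the call path implied by this hunk: callers now pass a model id instead of a request object. The model id below is a made-up example, and resolve_model is assumed to be the helper this module already imports.

# sketch only: how a caller would use the new Llama.build signature
model_id = "Llama3.1-8B-Instruct"        # hypothetical example id
config = MetaReferenceInferenceConfig()  # defaults from the config hunk above

if resolve_model(model_id) is None:
    raise RuntimeError(f"Unknown model: {model_id}, Run `llama model list`")

generator = Llama.build(config, model_id)  # checkpoint is now resolved from model_id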


@@ -43,90 +43,47 @@ class MetaReferenceInferenceImpl(
     ModelsProtocolPrivate,
 ):
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
+        print("MetaReferenceInferenceImpl init")
         self.config = config
         self.model = None
-        self.model_registry_helper = None
-        if config.model:
-            model = resolve_model(config.model)
-            if model is None:
-                raise RuntimeError(
-                    f"Unknown model: {config.model}, Run `llama model list`"
-                )
-            self.model_registry_helper = ModelRegistryHelper(
-                [
-                    build_model_alias(
-                        model.descriptor(),
-                        model.core_model_id.value,
-                    )
-                ],
-            )
-            self.model = model
-            # verify that the checkpoint actually is for this model lol
-        else:
-            print("inference model isn't pre-loaded")

-    async def _setup_model(self, model_id: str) -> Optional[Model]:
-        model = resolve_model(model_id)
-        if model is None:
-            raise RuntimeError(f"Unknown model: {model_id}, Run `llama model list`")
-        # self.model_registry_helper = ModelRegistryHelper(
-        #     [
-        #         build_model_alias(
-        #             model.descriptor(),
-        #             model.core_model_id.value,
-        #         )
-        #     ],
-        # )
-        # return await self.register_model(model)
-        return model
-
-    async def initialize(self) -> None:
-        if self.model is None:
-            raise RuntimeError("model hasn't been setup yet")
-        log.info(f"Loading model `{self.model.descriptor()}`")
+    async def initialize(self, model_id) -> None:
+        log.info(f"Loading model `{model_id}`")
         if self.config.create_distributed_process_group:
-            self.generator = LlamaModelParallelGenerator(self.config)
+            self.generator = LlamaModelParallelGenerator(self.config, model_id)
             self.generator.start()
         else:
-            self.generator = Llama.build(self.config)
-
-    async def _lazy_initialize(self, request) -> None:
-        if self.model is None:
-            raise RuntimeError("model hasn't been setup yet")
-        print(f"Lazy loading model `{self.model.descriptor()}`")
-        if self.config.create_distributed_process_group:
-            # with LlamaModelParallelGenerator(self.config, request) as resouce:
-            self.generator = LlamaModelParallelGenerator(self.config, request)
-            self.generator.start()
-        else:
-            self.generator = Llama.build(self.config, request)
+            self.generator = Llama.build(self.config, model_id)
+        self.model = model_id

     async def shutdown(self) -> None:
         if self.config.create_distributed_process_group:
             self.generator.stop()

     def check_model(self, request) -> None:
-        model = resolve_model(request.model)
-        if model is None:
+        if self.model is None:
+            raise RuntimeError(
+                "Inference model hasn't been initialized yet, please register your requested model or add your model in the resources first"
+            )
+        inference_model = resolve_model(self.model)
+        requested_model = resolve_model(request.model)
+        if requested_model is None:
             raise RuntimeError(
                 f"Unknown model: {request.model}, Run `llama model list`"
             )
-        elif self.model and model.descriptor() != self.model.descriptor():
+        elif requested_model.descriptor() != inference_model.descriptor():
             raise RuntimeError(
-                f"Model mismatch: {request.model} != {self.model.descriptor()}"
+                f"Model mismatch: {request.model} != {inference_model.descriptor()}"
             )

     async def unregister_model(self, model_id: str) -> None:
         pass

     async def register_model(self, model: LlamaStackModel) -> LlamaStackModel:
-        if self.model_registry_helper is None:
         llama_model = resolve_model(model.identifier)
         if llama_model is None:
             raise RuntimeError(
-                f"Unknown model: {model.identifier}, Run `llama model list`"
+                f"Unknown model: {model.identifier}, Please make sure your model is in the llama-models SKU list"
             )
         self.model_registry_helper = ModelRegistryHelper(
             [
@@ -140,6 +97,14 @@ class MetaReferenceInferenceImpl(
         print("model type", type(model))
         if model.model_type == ModelType.embedding_model:
             self._load_sentence_transformer_model(model.provider_resource_id)
+        if (
+            model.metadata
+            and "skip_initialize" in model.metadata
+            and model.metadata["skip_initialize"]
+        ):
+            return model
+
+        await self.initialize(model.identifier)
         return model

     async def completion(
@@ -171,10 +136,6 @@ class MetaReferenceInferenceImpl(
         return await self._nonstream_completion(request)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        if self.model is None:
-            self.model = await self._setup_model(request.model)
-            await self._lazy_initialize(request)
-
         def impl():
             stop_reason = None
@@ -224,10 +185,6 @@ class MetaReferenceInferenceImpl(
     async def _nonstream_completion(
         self, request: CompletionRequest
     ) -> CompletionResponse:
-        if self.model is None:
-            self.model = await self._setup_model(request.model)
-            await self._lazy_initialize(request)
-
         def impl():
             tokens = []
             logprobs = []
@@ -310,10 +267,6 @@ class MetaReferenceInferenceImpl(
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        if self.model is None:
-            self.model = await self._setup_model(request.model)
-            await self._lazy_initialize(request)
-
         def impl():
             tokens = []
             logprobs = []
@@ -359,10 +312,6 @@ class MetaReferenceInferenceImpl(
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        if self.model is None:
-            self.model = await self._setup_model(request.model)
-            await self._lazy_initialize(request)
-
         def impl():
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
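
Taken together, these hunks move model loading out of the constructor and request handlers and into registration: register_model now awaits initialize(model.identifier) unless the model's metadata sets skip_initialize. A rough sketch of that flow, assuming the names from the hunks above; the model id and the LlamaStackModel fields shown are hypothetical and may not match the actual resource schema.

# sketch only: registration now drives model loading
async def example(config: MetaReferenceInferenceConfig) -> None:
    impl = MetaReferenceInferenceImpl(config)

    model = LlamaStackModel(
        identifier="Llama3.1-8B-Instruct",       # hypothetical model id
        provider_id="meta-reference-inference",
        provider_resource_id="Llama3.1-8B-Instruct",
        metadata={},  # {"skip_initialize": True} would skip the initialize() call
    )

    await impl.register_model(model)  # internally awaits impl.initialize(model.identifier)
    # later requests must target the same model, otherwise check_model() raises
    await impl.shutdown()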


@@ -7,7 +7,7 @@
 import os
 from copy import deepcopy
 from functools import partial
-from typing import Any, Generator, Optional, Union
+from typing import Any, Generator

 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -36,9 +36,9 @@ class ModelRunner:
 def init_model_cb(
     config: MetaReferenceInferenceConfig,
-    request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
+    model_id: str,
 ):
-    llama = Llama.build(config, request)
+    llama = Llama.build(config, model_id)
     return ModelRunner(llama)
@@ -56,17 +56,12 @@ class LlamaModelParallelGenerator:
     def __init__(
         self,
         config: MetaReferenceInferenceConfig,
-        request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
+        model_id: str,
     ):
         print("LlamaModelParallelGenerator init")
         self.config = config
-        self.request = request
-        if config.model:
-            self.model = resolve_model(config.model)
-        elif request:
-            self.model = resolve_model(request.model)
-        else:
-            raise RuntimeError("you need to provide a model for inference")
+        self.model_id = model_id
+        self.model = resolve_model(model_id)

         # this is a hack because Agent's loop uses this to tokenize and check if input is too long
         # while the tool-use loop is going
@@ -89,7 +84,7 @@ class LlamaModelParallelGenerator:
         self.group = ModelParallelProcessGroup(
             model_parallel_size,
-            init_model_cb=partial(init_model_cb, self.config, self.request),
+            init_model_cb=partial(init_model_cb, self.config, self.model_id),
         )
         self.group.start()
         return self


@@ -16,7 +16,7 @@ providers:
   - provider_id: meta-reference-inference
     provider_type: inline::meta-reference
     config:
-      # model: ${env.INFERENCE_MODEL}
+      model: ${env.INFERENCE_MODEL}  # please make sure your inference model here is added as resource
       max_seq_len: 4096
      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
  memory:
@@ -73,10 +73,10 @@ metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
-models: []
-# - metadata: {}
-#   model_id: ${env.INFERENCE_MODEL}
-#   provider_id: meta-reference-inference
-#   provider_model_id: null
+models:
+- metadata: {}
+  model_id: ${env.INFERENCE_MODEL}
+  provider_id: meta-reference-inference
+  provider_model_id: null
 shields: []
 memory_banks: []
 datasets: []
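
The run.yaml change lists ${env.INFERENCE_MODEL} as a model resource, so startup registration is what triggers the provider's initialize(). A model could also be registered against a running stack; the sketch below assumes the llama-stack Python client's models.register endpoint, and the exact keyword arguments and port are assumptions, so check the client version you have.

# sketch only: registering the inference model as a resource at runtime
# (assumed client API; argument names may differ in your llama_stack_client version)
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")  # hypothetical address
client.models.register(
    model_id="Llama3.1-8B-Instruct",        # hypothetical model id
    provider_id="meta-reference-inference",
    metadata={"skip_initialize": False},    # True would defer weight loading
)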