temp commit

Botao Chen 2024-12-12 21:44:03 -08:00
parent 8efe33646d
commit de44af1501
9 changed files with 153 additions and 53 deletions


@@ -11,6 +11,8 @@ from typing import AsyncGenerator, List
from llama_models.sku_list import resolve_model
from llama_stack.apis.models import Model as LlamaStackModel
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import build_model_alias
@@ -41,49 +43,77 @@ class MetaReferenceInferenceImpl(
ModelsProtocolPrivate,
):
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
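# config.model is now optional; when it is unset, the model is resolved lazily on the first inference request.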
log.debug("MetaReferenceInferenceImpl init")
self.config = config
self.model_id = config.model
self.model = None
self.model_registry_helper = None
if config.model:
model = resolve_model(config.model)
if model is None:
raise RuntimeError(
f"Unknown model: {config.model}, Run `llama model list`"
)
self.model_registry_helper = ModelRegistryHelper(
[
build_model_alias(
model.descriptor(),
model.core_model_id.value,
)
],
)
self.model = model
# verify that the checkpoint actually is for this model lol
else:
log.info("Inference model isn't pre-loaded; it will be resolved lazily on the first request")
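# Resolve a model identifier to a native Llama model descriptor; used by the lazy-initialization path.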
async def _setup_model(self, model_id: str) -> Optional[Model]:
model = resolve_model(model_id)
if model is None:
raise RuntimeError(f"Unknown model: {model_id}, Run `llama model list`")
# self.model_registry_helper = ModelRegistryHelper(
# [
# build_model_alias(
# model.descriptor(),
# model.core_model_id.value,
# )
# ],
# )
# return await self.register_model(model)
return model
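# Eager path: resolve the configured model (falling back to its base_model from metadata) and load the generator at startup.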
async def initialize(self) -> None:
model = await self.model_store.get_model(self.model_id)
base_model = model.metadata.get("base_model") or self.model_id
self.model = resolve_model(base_model)
if self.model is None:
raise RuntimeError(
f"Unknown model: {self.model_id}. Please check that the model (or its base_model) is a native Llama model."
)
self.model_registry_helper = ModelRegistryHelper(
[
build_model_alias(
self.model.descriptor(),
self.model.core_model_id.value,
)
],
)
log.info(f"Loading model `{self.model.descriptor()}`")
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()
else:
self.generator = Llama.build(self.config)
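# Lazy path: build the generator on the first request when no model was pre-loaded during initialize().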
async def _lazy_initialize(self, request) -> None:
if self.model is None:
raise RuntimeError("model hasn't been setup yet")
log.info(f"Lazy loading model `{self.model.descriptor()}`")
if self.config.create_distributed_process_group:
# with LlamaModelParallelGenerator(self.config, request) as resource:
self.generator = LlamaModelParallelGenerator(self.config, request)
self.generator.start()
else:
self.generator = Llama.build(self.config, request)
async def shutdown(self) -> None:
if self.config.create_distributed_process_group:
self.generator.stop()
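# Verify that the requested model (or its base_model) resolves to the same Llama model that is loaded.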
async def check_model(self, request) -> None:
request_model = await self.model_store.get_model(request.model)
base_model = request_model.metadata.get("base_model") or request.model
model = resolve_model(base_model)
if model is None:
raise RuntimeError(
f"Unknown model: {request.model}. Run `llama model list`, or check that the model (or its base_model) is a native Llama model."
)
elif self.model and model.descriptor() != self.model.descriptor():
raise RuntimeError(
f"Model mismatch: {request.model} != {self.model.descriptor()}"
)
@@ -91,8 +121,23 @@ class MetaReferenceInferenceImpl(
async def unregister_model(self, model_id: str) -> None:
pass
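# When the provider started without a configured model, the registry helper is created here on first registration.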
async def register_model(self, model: LlamaStackModel) -> LlamaStackModel:
if self.model_registry_helper is None:
llama_model = resolve_model(model.identifier)
if llama_model is None:
raise RuntimeError(
f"Unknown model: {model.identifier}, Run `llama model list`"
)
self.model_registry_helper = ModelRegistryHelper(
[
build_model_alias(
llama_model.descriptor(),
llama_model.core_model_id.value,
)
],
)
model = await self.model_registry_helper.register_model(model)
log.debug(f"Registered model type: {type(model)}")
if model.model_type == ModelType.embedding_model:
self._load_sentence_transformer_model(model.provider_resource_id)
return model
@@ -117,7 +162,7 @@ class MetaReferenceInferenceImpl(
stream=stream,
logprobs=logprobs,
)
await self.check_model(request)
request = await request_with_localized_media(request)
if request.stream:
@@ -126,6 +171,10 @@ class MetaReferenceInferenceImpl(
return await self._nonstream_completion(request)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
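# Set up and load the model lazily if it was not configured at startup; the same guard appears in the other completion paths below.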
if self.model is None:
self.model = await self._setup_model(request.model)
await self._lazy_initialize(request)
def impl():
stop_reason = None
@@ -175,6 +224,10 @@ class MetaReferenceInferenceImpl(
async def _nonstream_completion(
self, request: CompletionRequest
) -> CompletionResponse:
if self.model is None:
self.model = await self._setup_model(request.model)
await self._lazy_initialize(request)
def impl():
tokens = []
logprobs = []
@@ -242,7 +295,7 @@ class MetaReferenceInferenceImpl(
stream=stream,
logprobs=logprobs,
)
await self.check_model(request)
request = await request_with_localized_media(request)
if self.config.create_distributed_process_group:
@@ -257,6 +310,10 @@ class MetaReferenceInferenceImpl(
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
if self.model is None:
self.model = await self._setup_model(request.model)
await self._lazy_initialize(request)
def impl():
tokens = []
logprobs = []
@@ -294,6 +351,7 @@ class MetaReferenceInferenceImpl(
if self.config.create_distributed_process_group:
async with SEMAPHORE:
log.debug("acquired inference SEMAPHORE")
return impl()
else:
return impl()
@@ -301,6 +359,10 @@ class MetaReferenceInferenceImpl(
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
if self.model is None:
self.model = await self._setup_model(request.model)
await self._lazy_initialize(request)
def impl():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(