Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-17 21:59:49 +00:00)
temp commit
This commit is contained in:
parent de44af1501
commit 35b1a6f2dc
7 changed files with 54 additions and 113 deletions
@@ -43,103 +43,68 @@ class MetaReferenceInferenceImpl(
    ModelsProtocolPrivate,
):
    def __init__(self, config: MetaReferenceInferenceConfig) -> None:
        print("MetaReferenceInferenceImpl init")
        self.config = config
        self.model = None
        self.model_registry_helper = None
        if config.model:
            model = resolve_model(config.model)
            if model is None:
                raise RuntimeError(
                    f"Unknown model: {config.model}, Run `llama model list`"
                )
            self.model_registry_helper = ModelRegistryHelper(
                [
                    build_model_alias(
                        model.descriptor(),
                        model.core_model_id.value,
                    )
                ],
            )
            self.model = model
            # verify that the checkpoint actually is for this model lol
        else:
            print("inference model isn't pre-loaded")
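With this change the constructor only pre-resolves a model when the config names one; otherwise loading is deferred until a request or a model registration arrives. A minimal sketch of the two paths, assuming keyword construction of MetaReferenceInferenceConfig and that its model field may be left unset (the model id and variable names are illustrative, not part of the commit):

# Illustrative only; uses the config fields referenced in this diff, not the full schema.
eager_cfg = MetaReferenceInferenceConfig(
    model="Llama3.1-8B-Instruct",           # resolved eagerly in __init__
    create_distributed_process_group=False,
)
lazy_cfg = MetaReferenceInferenceConfig(
    model=None,                             # "inference model isn't pre-loaded"
    create_distributed_process_group=False,
)

eager_impl = MetaReferenceInferenceImpl(eager_cfg)  # self.model set immediately
lazy_impl = MetaReferenceInferenceImpl(lazy_cfg)    # self.model stays None until first use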
    async def _setup_model(self, model_id: str) -> Optional[Model]:
        model = resolve_model(model_id)
        if model is None:
            raise RuntimeError(f"Unknown model: {model_id}, Run `llama model list`")
        # self.model_registry_helper = ModelRegistryHelper(
        #     [
        #         build_model_alias(
        #             model.descriptor(),
        #             model.core_model_id.value,
        #         )
        #     ],
        # )

        # return await self.register_model(model)
        return model
    async def initialize(self) -> None:
        if self.model is None:
            raise RuntimeError("model hasn't been setup yet")
        log.info(f"Loading model `{self.model.descriptor()}`")
    async def initialize(self, model_id) -> None:
        log.info(f"Loading model `{model_id}`")
        if self.config.create_distributed_process_group:
            self.generator = LlamaModelParallelGenerator(self.config)
            self.generator = LlamaModelParallelGenerator(self.config, model_id)
            self.generator.start()
        else:
            self.generator = Llama.build(self.config)
            self.generator = Llama.build(self.config, model_id)
    async def _lazy_initialize(self, request) -> None:
        if self.model is None:
            raise RuntimeError("model hasn't been setup yet")
        print(f"Lazy loading model `{self.model.descriptor()}`")
        if self.config.create_distributed_process_group:
            # with LlamaModelParallelGenerator(self.config, request) as resource:
            self.generator = LlamaModelParallelGenerator(self.config, request)
            self.generator.start()
        else:
            self.generator = Llama.build(self.config, request)
        self.model = model_id
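Each request handler changed later in this diff applies the same guard before generating: resolve the requested model if nothing is loaded yet, then lazily build the generator. A sketch of that pattern as a free function over an impl instance (the helper name ensure_model_loaded is illustrative; the methods it calls are the ones added above):

# Illustrative sketch of the guard the completion/chat-completion handlers apply below.
async def ensure_model_loaded(impl, request):
    if impl.model is None:
        # resolve the requested identifier into a llama-models Model
        impl.model = await impl._setup_model(request.model)
        # build impl.generator: model-parallel process group or single-process Llama.build
        await impl._lazy_initialize(request)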
    async def shutdown(self) -> None:
        if self.config.create_distributed_process_group:
            self.generator.stop()
    def check_model(self, request) -> None:
        model = resolve_model(request.model)
        if model is None:
        if self.model is None:
            raise RuntimeError(
                "Inference model hasn't been initialized yet, please register your requested model or add your model in the resources first"
            )
        inference_model = resolve_model(self.model)
        requested_model = resolve_model(request.model)
        if requested_model is None:
            raise RuntimeError(
                f"Unknown model: {request.model}, Run `llama model list`"
            )
        elif self.model and model.descriptor() != self.model.descriptor():
        elif requested_model.descriptor() != inference_model.descriptor():
            raise RuntimeError(
                f"Model mismatch: {request.model} != {self.model.descriptor()}"
                f"Model mismatch: {request.model} != {inference_model.descriptor()}"
            )
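A quick illustration of how the new requested_model/inference_model comparison behaves once a model is loaded; it assumes an impl whose self.model currently holds "Llama3.1-8B-Instruct" and a CompletionRequest constructed with just model and content (both identifiers are illustrative):

# Illustrative only: same model passes silently, a different one raises.
impl.check_model(CompletionRequest(model="Llama3.1-8B-Instruct", content="hi"))  # ok
impl.check_model(CompletionRequest(model="Llama3.2-3B-Instruct", content="hi"))
# -> RuntimeError: Model mismatch: Llama3.2-3B-Instruct != Llama3.1-8B-Instruct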
    async def unregister_model(self, model_id: str) -> None:
        pass
    async def register_model(self, model: LlamaStackModel) -> LlamaStackModel:
        if self.model_registry_helper is None:
            llama_model = resolve_model(model.identifier)
            if llama_model is None:
                raise RuntimeError(
                    f"Unknown model: {model.identifier}, Run `llama model list`"
                )
            self.model_registry_helper = ModelRegistryHelper(
                [
                    build_model_alias(
                        llama_model.descriptor(),
                        llama_model.core_model_id.value,
                    )
                ],
        llama_model = resolve_model(model.identifier)
        if llama_model is None:
            raise RuntimeError(
                f"Unknown model: {model.identifier}, Please make sure your model is in llama-models SKU list"
            )
        self.model_registry_helper = ModelRegistryHelper(
            [
                build_model_alias(
                    llama_model.descriptor(),
                    llama_model.core_model_id.value,
                )
            ],
        )
        model = await self.model_registry_helper.register_model(model)
        print("model type", type(model))
        if model.model_type == ModelType.embedding_model:
            self._load_sentence_transformer_model(model.provider_resource_id)

        if (
            model.metadata
            and "skip_initialize" in model.metadata
            and model.metadata["skip_initialize"]
        ):
            return model
        await self.initialize(model.identifier)
        return model
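With registration now building the ModelRegistryHelper on demand and honoring a skip_initialize metadata flag, loading can be deferred even for registered models. A rough usage sketch, assuming the registered resource is the llama-stack Model type (imported here as LlamaStackModel) with its usual identifier/provider fields; all values below are illustrative:

# Illustrative only; field values are made up, skip_initialize defers weight loading.
registered = await impl.register_model(
    LlamaStackModel(
        identifier="Llama3.1-8B-Instruct",
        provider_resource_id="Llama3.1-8B-Instruct",
        provider_id="meta-reference",
        metadata={"skip_initialize": True},
    )
)
# Without the flag, register_model calls initialize(model.identifier) and loads immediately.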
    async def completion(
@@ -171,10 +136,6 @@ class MetaReferenceInferenceImpl(
        return await self._nonstream_completion(request)

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        if self.model is None:
            self.model = await self._setup_model(request.model)
            await self._lazy_initialize(request)

        def impl():
            stop_reason = None
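Taken together with this guard, the first request against an instance built without a pre-loaded model resolves and loads it on the fly. A hedged end-to-end sketch that drives the non-streaming handler shown in the next hunk; the config construction, model id, and request fields are assumptions for illustration:

# Illustrative only: exercising the lazy path via the internal handler.
impl = MetaReferenceInferenceImpl(MetaReferenceInferenceConfig(model=None))
request = CompletionRequest(model="Llama3.1-8B-Instruct", content="Say hello")
# self.model is None, so the handler resolves the model and builds the generator first.
response = await impl._nonstream_completion(request)
# Later requests reuse self.generator; check_model() rejects a different model id.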
@@ -224,10 +185,6 @@ class MetaReferenceInferenceImpl(
    async def _nonstream_completion(
        self, request: CompletionRequest
    ) -> CompletionResponse:
        if self.model is None:
            self.model = await self._setup_model(request.model)
            await self._lazy_initialize(request)

        def impl():
            tokens = []
            logprobs = []
@@ -310,10 +267,6 @@ class MetaReferenceInferenceImpl(
    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        if self.model is None:
            self.model = await self._setup_model(request.model)
            await self._lazy_initialize(request)

        def impl():
            tokens = []
            logprobs = []
@@ -359,10 +312,6 @@ class MetaReferenceInferenceImpl(
    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator:
        if self.model is None:
            self.model = await self._setup_model(request.model)
            await self._lazy_initialize(request)

        def impl():
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
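The streaming chat handler shares the same lazy-load guard and then yields ChatCompletionResponseStreamChunk events. A rough consumption sketch, assuming a ChatCompletionRequest built from a single UserMessage and a text-only delta; the request fields and loop body are illustrative, not part of the commit:

# Illustrative only: the first streamed chat request also triggers the lazy load.
request = ChatCompletionRequest(
    model="Llama3.1-8B-Instruct",
    messages=[UserMessage(content="What is Llama Stack?")],
    stream=True,
)
async for chunk in impl._stream_chat_completion(request):
    # each chunk wraps a ChatCompletionResponseEvent carrying a delta
    print(chunk.event.delta, end="")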