temp commit

This commit is contained in:
Botao Chen 2024-12-16 21:43:30 -08:00
parent 30f6eb282f
commit 81e1957446
10 changed files with 54 additions and 39 deletions

View file

@ -46,13 +46,16 @@ class MetaReferenceInferenceImpl(
self.config = config
self.model = None
async def initialize(self, model_id) -> None:
async def initialize(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`")
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(self.config, model_id)
print("I reach create_distributed_process_group")
self.generator = LlamaModelParallelGenerator(
self.config, model_id, llama_model
)
self.generator.start()
else:
self.generator = Llama.build(self.config, model_id)
self.generator = Llama.build(self.config, model_id, llama_model)
self.model = model_id
@ -65,26 +68,27 @@ class MetaReferenceInferenceImpl(
raise RuntimeError(
"Inference model hasn't been initialized yet, please register your requested model or add your model in the resouces first"
)
inference_model = resolve_model(self.model)
requested_model = resolve_model(request.model)
if requested_model is None:
if request.model is None:
raise RuntimeError(
f"Unknown model: {request.model}, Run `llama model list`"
)
elif requested_model.descriptor() != inference_model.descriptor():
raise RuntimeError(
f"Model mismatch: {request.model} != {inference_model.descriptor()}"
)
elif request.model != self.model:
raise RuntimeError(f"Model mismatch: {request.model} != {self.model}")
async def unregister_model(self, model_id: str) -> None:
pass
async def register_model(self, model: LlamaStackModel) -> LlamaStackModel:
llama_model = resolve_model(model.identifier)
llama_model = (
resolve_model(model.metadata["llama_model"])
if "llama_model" in model.metadata
else resolve_model(model.identifier)
)
if llama_model is None:
raise RuntimeError(
f"Unknown model: {model.identifier}, Please make sure your model is in llama-models SKU list"
"Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list"
)
self.model_registry_helper = ModelRegistryHelper(
[
build_model_alias(
@ -94,6 +98,7 @@ class MetaReferenceInferenceImpl(
],
)
model = await self.model_registry_helper.register_model(model)
if model.model_type == ModelType.embedding:
self._load_sentence_transformer_model(model.provider_resource_id)
@ -103,7 +108,7 @@ class MetaReferenceInferenceImpl(
and model.metadata["skip_initialize"]
):
return model
await self.initialize(model.identifier)
await self.initialize(model.identifier, llama_model)
return model
async def completion(