temp_commit

Botao Chen 2024-12-12 17:15:05 -08:00
parent 53b3a1e345
commit 8efe33646d
2 changed files with 24 additions and 11 deletions


@@ -87,6 +87,9 @@ class Llama:
         This method initializes the distributed process group, sets the device to CUDA,
         and loads the pre-trained model and tokenizer.
         """
+        model = await self.model_store.get_model(config.model)
+        base_model = model.metadata["base_model"] or self.model_id
+        self.model = resolve_model(base_model)
         model = resolve_model(config.model)
         llama_model = model.core_model_id.value
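The added lines route model resolution through the stack's model store: the configured ID is looked up first, and if its metadata names a `base_model`, that native Llama SKU is what `resolve_model` receives. A minimal sketch of the lookup, assuming the `model_store.get_model` API used in this diff (the helper name is hypothetical, and `.get()` is used here to tolerate records without a `base_model` key):

```python
from llama_models.sku_list import resolve_model


async def resolve_loaded_model(model_store, model_id: str):
    # Fine-tuned checkpoints are expected to point at their native Llama SKU
    # via metadata["base_model"]; plain Llama models fall back to their own ID.
    record = await model_store.get_model(model_id)
    base_model = record.metadata.get("base_model") or model_id
    resolved = resolve_model(base_model)
    if resolved is None:
        raise RuntimeError(
            f"Unknown model: {model_id}, please check if the model "
            "or base_model is a native Llama model"
        )
    return resolved
```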


@@ -24,6 +24,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_media_to_url,
     request_has_media,
 )
 from .config import MetaReferenceInferenceConfig
 from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator
@@ -41,9 +42,18 @@ class MetaReferenceInferenceImpl(
 ):
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
         self.config = config
-        model = resolve_model(config.model)
-        if model is None:
-            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
+        self.model_id = config.model
+
+    async def initialize(self) -> None:
+        model = await self.model_store.get_model(self.model_id)
+        base_model = model.metadata["base_model"] or self.model_id
+        self.model = resolve_model(base_model)
+        if self.model is None:
+            raise RuntimeError(
+                f"Unknown model: {self.model_id}, please check if the model or base_model is a native Llama model"
+            )
         self.model_registry_helper = ModelRegistryHelper(
             [
                 build_model_alias(
@@ -52,11 +62,9 @@ class MetaReferenceInferenceImpl(
                 )
             ],
         )
-        self.model = model
-        # verify that the checkpoint actually is for this model lol
-
-    async def initialize(self) -> None:
         log.info(f"Loading model `{self.model.descriptor()}`")
         if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(self.config)
             self.generator.start()
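With `__init__` reduced to recording `config.model` and the store lookup moved into `initialize()`, a fine-tuned checkpoint can be served by registering it with a `base_model` metadata entry naming the underlying native Llama SKU. A hedged usage sketch (the client call, model IDs, and port are assumptions, not part of this commit):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")

# Hypothetical registration: the fine-tuned model keeps its own ID, while
# metadata["base_model"] tells the meta-reference provider which native
# Llama SKU to resolve and load at initialize() time.
client.models.register(
    model_id="my-org/llama-3.1-8b-sft",
    provider_id="meta-reference",
    metadata={"base_model": "Llama3.1-8B-Instruct"},
)
```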
@@ -67,11 +75,13 @@ class MetaReferenceInferenceImpl(
         if self.config.create_distributed_process_group:
             self.generator.stop()
 
-    def check_model(self, request) -> None:
-        model = resolve_model(request.model)
+    async def check_model(self, request) -> None:
+        request_model = await self.model_store.get_model(request.model)
+        base_model = request_model.metadata["base_model"] or request.model
+        model = resolve_model(base_model)
         if model is None:
             raise RuntimeError(
-                f"Unknown model: {request.model}, Run `llama model list`"
+                f"Unknown model: {request.model}, please check if the model or base_model is a native Llama model"
             )
         elif model.descriptor() != self.model.descriptor():
             raise RuntimeError(
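`check_model` is now async because it consults the model store before falling back to `resolve_model`, and it still rejects requests whose resolved descriptor differs from the checkpoint loaded at startup. An illustration of that final guard, with assumed SKU names:

```python
from llama_models.sku_list import resolve_model

# Assumed example IDs: the provider loaded an Instruct checkpoint, but the
# request's base model resolves to the pretrained variant, so descriptors differ.
loaded = resolve_model("Llama3.1-8B-Instruct")
requested = resolve_model("Llama3.1-8B")

if requested.descriptor() != loaded.descriptor():
    raise RuntimeError(
        f"Model mismatch: requested `{requested.descriptor()}` "
        f"but this provider loaded `{loaded.descriptor()}`"
    )
```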
@@ -107,7 +117,7 @@ class MetaReferenceInferenceImpl(
             stream=stream,
             logprobs=logprobs,
         )
-        self.check_model(request)
+        await self.check_model(request)
         request = await request_with_localized_media(request)
 
         if request.stream:
@@ -232,7 +242,7 @@ class MetaReferenceInferenceImpl(
             stream=stream,
             logprobs=logprobs,
         )
-        self.check_model(request)
+        await self.check_model(request)
         request = await request_with_localized_media(request)
 
         if self.config.create_distributed_process_group:
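Both call sites change from `self.check_model(request)` to `await self.check_model(request)`; since the method now performs an async store lookup, omitting the `await` would only create a coroutine and skip validation. A small illustration of the difference:

```python
async def handle(impl, request):
    await impl.check_model(request)  # runs the store lookup and descriptor check
    # impl.check_model(request)      # without await: un-awaited coroutine, no validation
```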