mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 01:03:59 +00:00)

Commit 8efe33646d ("temp_commit"), parent 53b3a1e345
2 changed files with 24 additions and 11 deletions
@@ -87,6 +87,9 @@ class Llama:
         This method initializes the distributed process group, sets the device to CUDA,
         and loads the pre-trained model and tokenizer.
         """
+        model = await self.model_store.get_model(config.model)
+        base_model = model.metadata["base_model"] or self.model_id
+        self.model = resolve_model(base_model)
         model = resolve_model(config.model)
         llama_model = model.core_model_id.value
 
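The hunk above sits in the model build path whose docstring describes joining the distributed process group, pinning the process to its CUDA device, and loading the weights and tokenizer. A minimal, self-contained sketch of that setup (illustrative only, not this file's code; it assumes a torchrun-style launcher that sets LOCAL_RANK and the rendezvous environment variables):

import os

import torch
import torch.distributed as dist


def setup_distributed() -> int:
    # One process per GPU: join the default process group over NCCL and
    # pin this rank to its local CUDA device before loading any weights.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    return local_rank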
@@ -24,6 +24,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_media_to_url,
     request_has_media,
 )
 
 from .config import MetaReferenceInferenceConfig
 from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator
@@ -41,9 +42,18 @@ class MetaReferenceInferenceImpl(
 ):
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
         self.config = config
-        model = resolve_model(config.model)
-        if model is None:
-            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
+        self.model_id = config.model
+
+    async def initialize(self) -> None:
+        model = await self.model_store.get_model(self.model_id)
+        base_model = model.metadata["base_model"] or self.model_id
+        self.model = resolve_model(base_model)
+
+        if self.model is None:
+            raise RuntimeError(
+                f"Unknown model: {self.model_id}, Run please check if the model or base_Model is a native llama model"
+            )
+
         self.model_registry_helper = ModelRegistryHelper(
             [
                 build_model_alias(
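The rewritten constructor above only records the requested model id; the actual lookup moves into the async initialize(), where the registered model is fetched from the model store and its metadata["base_model"] entry, if present, is used as the descriptor handed to resolve_model(). A small sketch of that lookup order, with stand-in names (ModelRecord, KNOWN_MODELS, resolve_native_model) that are illustrative rather than llama-stack APIs:

from dataclasses import dataclass, field
from typing import Optional


# Stand-in for a registry entry returned by the model store.
@dataclass
class ModelRecord:
    identifier: str
    metadata: dict = field(default_factory=dict)


# Stand-in for resolve_model(): only native Llama descriptors resolve.
KNOWN_MODELS = {"Llama3.1-8B-Instruct"}


def resolve_model(descriptor: str) -> Optional[str]:
    return descriptor if descriptor in KNOWN_MODELS else None


def resolve_native_model(record: ModelRecord) -> str:
    # Prefer the declared base model (e.g. for a fine-tuned checkpoint);
    # otherwise the identifier itself must be a native Llama model.
    base_model = record.metadata.get("base_model") or record.identifier
    resolved = resolve_model(base_model)
    if resolved is None:
        raise RuntimeError(f"Unknown model: {record.identifier}")
    return resolved


# A fine-tune that points at a native base model resolves successfully:
assert resolve_native_model(
    ModelRecord("my-finetune", {"base_model": "Llama3.1-8B-Instruct"})
) == "Llama3.1-8B-Instruct"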
@@ -52,11 +62,9 @@ class MetaReferenceInferenceImpl(
                 )
             ],
         )
-        self.model = model
-        # verify that the checkpoint actually is for this model lol
-
-    async def initialize(self) -> None:
         log.info(f"Loading model `{self.model.descriptor()}`")
 
         if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(self.config)
             self.generator.start()
@@ -67,11 +75,13 @@ class MetaReferenceInferenceImpl(
         if self.config.create_distributed_process_group:
             self.generator.stop()
 
-    def check_model(self, request) -> None:
-        model = resolve_model(request.model)
+    async def check_model(self, request) -> None:
+        request_model = await self.model_store.get_model(request.model)
+        base_model = request_model.metadata["base_model"] or request.model
+        model = resolve_model(base_model)
         if model is None:
             raise RuntimeError(
-                f"Unknown model: {request.model}, Run `llama model list`"
+                f"Unknown model: {request.model}, Run please check if the model or base_Model is a native llama model"
             )
         elif model.descriptor() != self.model.descriptor():
             raise RuntimeError(
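check_model() above becomes a coroutine because it now awaits the model store, which is why the following hunks also change its call sites. A short sketch of the failure mode those call-site edits avoid: calling an async def without await only creates a coroutine object, so the validation never actually runs (the names here are illustrative, not the provider's code):

import asyncio


async def check_model(model_id: str) -> None:
    # Illustrative validation; the real method consults the model store.
    if model_id != "expected-model":
        raise RuntimeError(f"Unknown model: {model_id}")


async def handle_request(model_id: str) -> None:
    # check_model(model_id)      # bug: un-awaited coroutine, no error is ever raised
    await check_model(model_id)  # correct: the check executes before inference


asyncio.run(handle_request("expected-model"))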
@@ -107,7 +117,7 @@ class MetaReferenceInferenceImpl(
             stream=stream,
             logprobs=logprobs,
         )
-        self.check_model(request)
+        await self.check_model(request)
         request = await request_with_localized_media(request)
 
         if request.stream:
@@ -232,7 +242,7 @@ class MetaReferenceInferenceImpl(
             stream=stream,
             logprobs=logprobs,
         )
-        self.check_model(request)
+        await self.check_model(request)
         request = await request_with_localized_media(request)
 
         if self.config.create_distributed_process_group: