Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-02 08:44:44 +00:00

temp commit

This commit is contained in: parent 8efe33646d, commit de44af1501
9 changed files with 153 additions and 53 deletions

@@ -90,6 +90,7 @@ class InferenceRouter(Inference):
         metadata: Optional[Dict[str, Any]] = None,
         model_type: Optional[ModelType] = None,
     ) -> None:
+        print("inference router")
         await self.routing_table.register_model(
             model_id, provider_model_id, provider_id, metadata, model_type
         )

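Note: the hunk above only adds a debug print, but it also shows the router's shape: InferenceRouter.register_model does no work of its own and forwards every argument, including the newer model_type field, to the routing table. The sketch below illustrates that delegation with stand-in classes; the names mirror the diff, but the bodies are placeholders rather than the repository's code.

import asyncio
from typing import Any, Dict, Optional


class RoutingTableStub:
    # stand-in for the real routing table: it just records registrations
    def __init__(self) -> None:
        self.models: Dict[str, Dict[str, Any]] = {}

    async def register_model(self, model_id, provider_model_id, provider_id, metadata, model_type) -> None:
        self.models[model_id] = {
            "provider_model_id": provider_model_id,
            "provider_id": provider_id,
            "metadata": metadata or {},
            "model_type": model_type,
        }


class InferenceRouterSketch:
    # the router only forwards registration to its routing table
    def __init__(self, routing_table: RoutingTableStub) -> None:
        self.routing_table = routing_table

    async def register_model(
        self,
        model_id: str,
        provider_model_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        model_type: Optional[str] = None,
    ) -> None:
        await self.routing_table.register_model(
            model_id, provider_model_id, provider_id, metadata, model_type
        )


table = RoutingTableStub()
asyncio.run(InferenceRouterSketch(table).register_model("Llama3.2-3B-Instruct", provider_id="meta-reference-inference"))
print(table.models)
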
@@ -32,6 +32,7 @@ def get_impl_api(p: Any) -> Api:
 async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:

     api = get_impl_api(p)
+    print("registering object with provider", api)

     assert obj.provider_id != "remote", "Remote provider should not be registered"

@@ -169,6 +170,7 @@ class CommonRoutingTableImpl(RoutingTable):
     async def register_object(
         self, obj: RoutableObjectWithProvider
     ) -> RoutableObjectWithProvider:
+
         # Get existing objects from registry
         existing_obj = await self.dist_registry.get(obj.type, obj.identifier)

@@ -181,7 +183,12 @@ class CommonRoutingTableImpl(RoutingTable):

         p = self.impls_by_provider_id[obj.provider_id]

+        if obj is None:
+            print("obj is None")
+
         registered_obj = await register_object_with_provider(obj, p)
+        if registered_obj is None:
+            print("registered_obj is None")
         # TODO: This needs to be fixed for all APIs once they return the registered object
         if obj.type == ResourceType.model.value:
             await self.dist_registry.register(registered_obj)

@@ -211,6 +218,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         metadata: Optional[Dict[str, Any]] = None,
         model_type: Optional[ModelType] = None,
     ) -> Model:
+        print("register_model", model_id)
         if provider_model_id is None:
             provider_model_id = model_id
         if provider_id is None:

@@ -239,7 +247,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             metadata=metadata,
             model_type=model_type,
         )
+        if model is None:
+            print("model is None!!!")
+        print("before registered_model")
         registered_model = await self.register_object(model)
+        print("after registered_model")
         return registered_model

     async def unregister_model(self, model_id: str) -> None:

@@ -15,6 +15,10 @@ async def get_provider_impl(
 ):
     from .inference import MetaReferenceInferenceImpl

+    print("get_provider_impl")
+
     impl = MetaReferenceInferenceImpl(config)
-    await impl.initialize()
+    if config.model:
+        # pre-load the model if the model is in the config
+        await impl.initialize()
     return impl

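Note: the change above makes the provider factory pre-load only when config.model is set; with no model configured, loading is deferred until a request names one. A minimal, self-contained sketch of that eager-vs-lazy split (Config, Impl and get_provider_impl here are illustrative stand-ins, not the provider's actual module):

import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class Config:
    model: Optional[str] = None  # mirrors MetaReferenceInferenceConfig.model being optional


class Impl:
    def __init__(self, config: Config) -> None:
        self.config = config
        self.loaded = False

    async def initialize(self) -> None:
        self.loaded = True  # stand-in for loading checkpoints onto the device


async def get_provider_impl(config: Config) -> Impl:
    impl = Impl(config)
    if config.model:
        # pre-load only when the model is pinned in the config
        await impl.initialize()
    return impl


print(asyncio.run(get_provider_impl(Config())).loaded)            # False: lazy path
print(asyncio.run(get_provider_impl(Config(model="any"))).loaded)  # True: eager path
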
@@ -10,16 +10,13 @@ from llama_models.datatypes import * # noqa: F403
 from llama_models.sku_list import resolve_model

 from llama_stack.apis.inference import * # noqa: F401, F403
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, field_validator

 from llama_stack.providers.utils.inference import supported_inference_models


 class MetaReferenceInferenceConfig(BaseModel):
-    model: str = Field(
-        default="Llama3.2-3B-Instruct",
-        description="Model descriptor from `llama model list`",
-    )
+    model: Optional[str] = None
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1

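Note: dropping the Field(...) default makes model genuinely optional, which is what allows the lazy path elsewhere in this commit. A minimal sketch of the before/after behaviour, assuming pydantic is installed (this is a trimmed copy of the config, not the full class):

from typing import Optional

from pydantic import BaseModel


class ConfigSketch(BaseModel):
    # previously: model: str = Field(default="Llama3.2-3B-Instruct", ...)
    model: Optional[str] = None
    torch_seed: Optional[int] = None
    max_seq_len: int = 4096
    max_batch_size: int = 1


eager = ConfigSketch(model="Llama3.2-3B-Instruct")
lazy = ConfigSketch()  # leaving model unset is now valid and means "load at request time"
print(eager.model, lazy.model)
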
@@ -79,6 +79,7 @@ class Llama:
         config: Union[
             MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
         ],
+        request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
     ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.

@@ -87,10 +88,13 @@ class Llama:
         This method initializes the distributed process group, sets the device to CUDA,
         and loads the pre-trained model and tokenizer.
         """
-        model = await self.model_store.get_model(config.model)
-        base_model = model.metadata["base_model"] or self.model_id
-        self.model = resolve_model(base_model)
-        model = resolve_model(config.model)
+        if config.model:
+            model = resolve_model(config.model)
+        elif request:
+            model = resolve_model(request.model)
+        else:
+            raise RuntimeError("you need to provide a model for inference")
+
         llama_model = model.core_model_id.value

         if not torch.distributed.is_initialized():

@@ -11,6 +11,8 @@ from typing import AsyncGenerator, List

 from llama_models.sku_list import resolve_model

+from llama_stack.apis.models import Model as LlamaStackModel
+
 from llama_models.llama3.api.datatypes import * # noqa: F403

 from llama_stack.providers.utils.inference.model_registry import build_model_alias

@@ -41,49 +43,77 @@ class MetaReferenceInferenceImpl(
     ModelsProtocolPrivate,
 ):
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
+        print("MetaReferenceInferenceImpl init")
         self.config = config
         self.model_id = config.model
         self.model = None
         self.model_registry_helper = None
+        if config.model:
+            model = resolve_model(config.model)
+            if model is None:
+                raise RuntimeError(
+                    f"Unknown model: {config.model}, Run `llama model list`"
+                )
+            self.model_registry_helper = ModelRegistryHelper(
+                [
+                    build_model_alias(
+                        model.descriptor(),
+                        model.core_model_id.value,
+                    )
+                ],
+            )
+            self.model = model
+            # verify that the checkpoint actually is for this model lol
+        else:
+            print("inference model isn't pre-loaded")

+    async def _setup_model(self, model_id: str) -> Optional[Model]:
+        model = resolve_model(model_id)
+        if model is None:
+            raise RuntimeError(f"Unknown model: {model_id}, Run `llama model list`")
+        # self.model_registry_helper = ModelRegistryHelper(
+        #     [
+        #         build_model_alias(
+        #             model.descriptor(),
+        #             model.core_model_id.value,
+        #         )
+        #     ],
+        # )
+
+        # return await self.register_model(model)
+        return model
+
     async def initialize(self) -> None:
-        model = await self.model_store.get_model(self.model_id)
-        base_model = model.metadata["base_model"] or self.model_id
-        self.model = resolve_model(base_model)
-
         if self.model is None:
-            raise RuntimeError(
-                f"Unknown model: {self.model_id}, Run please check if the model or base_Model is a native llama model"
-            )
-
-        self.model_registry_helper = ModelRegistryHelper(
-            [
-                build_model_alias(
-                    model.descriptor(),
-                    model.core_model_id.value,
-                )
-            ],
-        )
-
+            raise RuntimeError("model hasn't been setup yet")
         log.info(f"Loading model `{self.model.descriptor()}`")

         if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(self.config)
             self.generator.start()
         else:
             self.generator = Llama.build(self.config)

+    async def _lazy_initialize(self, request) -> None:
+        if self.model is None:
+            raise RuntimeError("model hasn't been setup yet")
+        print(f"Lazy loading model `{self.model.descriptor()}`")
+        if self.config.create_distributed_process_group:
+            # with LlamaModelParallelGenerator(self.config, request) as resouce:
+            self.generator = LlamaModelParallelGenerator(self.config, request)
+            self.generator.start()
+        else:
+            self.generator = Llama.build(self.config, request)
+
     async def shutdown(self) -> None:
         if self.config.create_distributed_process_group:
             self.generator.stop()

-    async def check_model(self, request) -> None:
-        request_model = await self.model_store.get_model(request.model)
-        base_model = request_model.metadata["base_model"] or request.model
-        model = resolve_model(base_model)
+    def check_model(self, request) -> None:
+        model = resolve_model(request.model)
         if model is None:
             raise RuntimeError(
-                f"Unknown model: {request.model}, Run please check if the model or base_Model is a native llama model"
+                f"Unknown model: {request.model}, Run `llama model list`"
             )
-        elif model.descriptor() != self.model.descriptor():
+        elif self.model and model.descriptor() != self.model.descriptor():
             raise RuntimeError(
                 f"Model mismatch: {request.model} != {self.model.descriptor()}"
             )

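Note: the shape introduced above is: __init__ resolves and registers the model only when config.model is present, otherwise self.model stays None; _setup_model resolves a model named by an incoming request; and _lazy_initialize builds the generator on first use. A condensed, runnable sketch of that lazy path with string placeholders instead of real checkpoints or the registry helper:

import asyncio
from typing import Optional


class LazyInferenceSketch:
    def __init__(self, configured_model: Optional[str] = None) -> None:
        self.model: Optional[str] = configured_model  # None means "not set up yet"
        self.generator = None
        if configured_model:
            self.generator = f"generator({configured_model})"  # eager path

    async def _setup_model(self, model_id: str) -> str:
        # stand-in for resolve_model(model_id) plus registry bookkeeping
        return model_id

    async def _lazy_initialize(self, model_id: str) -> None:
        if self.model is None:
            raise RuntimeError("model hasn't been setup yet")
        self.generator = f"generator({model_id})"  # stand-in for Llama.build(config, request)

    async def completion(self, model_id: str) -> str:
        # mirrors _stream/_nonstream_*: set up on the first request, then serve
        if self.model is None:
            self.model = await self._setup_model(model_id)
            await self._lazy_initialize(model_id)
        return f"served by {self.generator}"


print(asyncio.run(LazyInferenceSketch().completion("Llama3.2-3B-Instruct")))
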
@@ -91,8 +121,23 @@ class MetaReferenceInferenceImpl(
     async def unregister_model(self, model_id: str) -> None:
         pass

-    async def register_model(self, model: Model) -> Model:
+    async def register_model(self, model: LlamaStackModel) -> LlamaStackModel:
+        if self.model_registry_helper is None:
+            llama_model = resolve_model(model.identifier)
+            if llama_model is None:
+                raise RuntimeError(
+                    f"Unknown model: {model.identifier}, Run `llama model list`"
+                )
+            self.model_registry_helper = ModelRegistryHelper(
+                [
+                    build_model_alias(
+                        llama_model.descriptor(),
+                        llama_model.core_model_id.value,
+                    )
+                ],
+            )
         model = await self.model_registry_helper.register_model(model)
+        print("model type", type(model))
         if model.model_type == ModelType.embedding_model:
             self._load_sentence_transformer_model(model.provider_resource_id)
         return model

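Note: register_model now tolerates starting without a helper: when nothing was configured up front, the first registration resolves the model and builds the alias table (ModelRegistryHelper([build_model_alias(...)])) on the spot. A tiny stand-alone sketch of that create-on-first-registration idea, using a plain dict in place of the real helper and a fake known-model table:

from typing import Dict, Optional

KNOWN_MODELS = {"Llama3.2-3B-Instruct": "llama3_2_3b_instruct"}  # stand-in for resolve_model


class RegistrySketch:
    def __init__(self) -> None:
        # mirrors self.model_registry_helper = None when no model was configured
        self.aliases: Optional[Dict[str, str]] = None

    def register_model(self, identifier: str) -> str:
        if self.aliases is None:
            core_id = KNOWN_MODELS.get(identifier)
            if core_id is None:
                raise RuntimeError(f"Unknown model: {identifier}, Run `llama model list`")
            # build the alias table on the first registration
            self.aliases = {identifier: core_id}
        return self.aliases[identifier]


print(RegistrySketch().register_model("Llama3.2-3B-Instruct"))
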
@@ -117,7 +162,7 @@
             stream=stream,
             logprobs=logprobs,
         )
-        await self.check_model(request)
+        self.check_model(request)
         request = await request_with_localized_media(request)

         if request.stream:

@@ -126,6 +171,10 @@
         return await self._nonstream_completion(request)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        if self.model is None:
+            self.model = await self._setup_model(request.model)
+            await self._lazy_initialize(request)
+
         def impl():
             stop_reason = None

@@ -175,6 +224,10 @@
     async def _nonstream_completion(
         self, request: CompletionRequest
     ) -> CompletionResponse:
+        if self.model is None:
+            self.model = await self._setup_model(request.model)
+            await self._lazy_initialize(request)
+
         def impl():
             tokens = []
             logprobs = []

@@ -242,7 +295,7 @@
             stream=stream,
             logprobs=logprobs,
         )
-        await self.check_model(request)
+        self.check_model(request)
         request = await request_with_localized_media(request)

         if self.config.create_distributed_process_group:

@@ -257,6 +310,10 @@
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
+        if self.model is None:
+            self.model = await self._setup_model(request.model)
+            await self._lazy_initialize(request)
+
         def impl():
             tokens = []
             logprobs = []

@@ -294,6 +351,7 @@

         if self.config.create_distributed_process_group:
             async with SEMAPHORE:
+                print("after SEMAPHORE")
                 return impl()
         else:
             return impl()

@@ -301,6 +359,10 @@
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
+        if self.model is None:
+            self.model = await self._setup_model(request.model)
+            await self._lazy_initialize(request)
+
         def impl():
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(

@@ -7,7 +7,7 @@
 import os
 from copy import deepcopy
 from functools import partial
-from typing import Any, Generator
+from typing import Any, Generator, Optional, Union

 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer

@@ -34,8 +34,11 @@ class ModelRunner:
         raise ValueError(f"Unexpected task type {type(req)}")


-def init_model_cb(config: MetaReferenceInferenceConfig):
-    llama = Llama.build(config)
+def init_model_cb(
+    config: MetaReferenceInferenceConfig,
+    request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
+):
+    llama = Llama.build(config, request)
     return ModelRunner(llama)


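Note: init_model_cb now accepts the optional request so the callback handed to the process group can forward it to Llama.build. functools.partial is what threads the extra argument through without changing how the callback is invoked; a generic, runnable illustration (build and spawn are placeholders, not the repo's classes):

from functools import partial
from typing import Optional


def build(config: str, request: Optional[str] = None) -> str:
    # stand-in for Llama.build(config, request)
    return f"model for {request or config}"


def spawn(init_cb) -> str:
    # stand-in for ModelParallelProcessGroup: it only ever calls init_cb()
    return init_cb()


# before: partial(init_model_cb, config); after: partial(init_model_cb, config, request)
print(spawn(partial(build, "config-model")))
print(spawn(partial(build, "config-model", "request-model")))
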
@@ -50,9 +53,21 @@ class LlamaModelParallelGenerator:
     clear at the callsite why we need to use a context manager.
     """

-    def __init__(self, config: MetaReferenceInferenceConfig):
+    def __init__(
+        self,
+        config: MetaReferenceInferenceConfig,
+        request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
+    ):
+        print("LlamaModelParallelGenerator init")
         self.config = config
-        self.model = resolve_model(self.config.model)
+        self.request = request
+        if config.model:
+            self.model = resolve_model(config.model)
+        elif request:
+            self.model = resolve_model(request.model)
+        else:
+            raise RuntimeError("you need to provide a model for inference")
+
         # this is a hack because Agent's loop uses this to tokenize and check if input is too long
         # while the tool-use loop is going
         checkpoint_dir = model_checkpoint_dir(self.model)

@@ -66,9 +81,15 @@ class LlamaModelParallelGenerator:
         self.__exit__(None, None, None)

     def __enter__(self):
+        print("enter LlamaModelParallelGenerator")
+        if self.config.model_parallel_size:
+            model_parallel_size = self.config.model_parallel_size
+        else:
+            model_parallel_size = resolve_model(self.model).pth_file_count
+
         self.group = ModelParallelProcessGroup(
-            self.config.model_parallel_size,
-            init_model_cb=partial(init_model_cb, self.config),
+            model_parallel_size,
+            init_model_cb=partial(init_model_cb, self.config, self.request),
         )
         self.group.start()
         return self

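Note: __enter__ now picks the parallel size from the config when it is set and otherwise falls back to the checkpoint's .pth shard count, and the generator stays a start-in-__enter__ / stop-in-__exit__ context manager. A compact sketch of both ideas with placeholder values (the shard count and process group are fakes):

from typing import Optional


class GeneratorSketch:
    def __init__(self, configured_size: Optional[int], pth_file_count: int) -> None:
        self.configured_size = configured_size
        self.pth_file_count = pth_file_count
        self.started = False

    def __enter__(self):
        # the config override wins; otherwise derive the size from the checkpoint shards
        size = self.configured_size or self.pth_file_count
        print(f"starting process group with model_parallel_size={size}")
        self.started = True
        return self

    def __exit__(self, exc_type, exc, tb):
        self.started = False


with GeneratorSketch(configured_size=None, pth_file_count=2) as gen:
    assert gen.started
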
@@ -27,8 +27,7 @@ def supported_inference_models() -> List[Model]:
         m
         for m in all_registered_models()
         if (
-            m.model_family
-            in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}
+            m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
             or is_supported_safety_model(m)
         )
     ]

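Note: the filter above now keeps only the llama3_1 and llama3_2 families (llama3_3 is dropped here) plus supported safety models. A toy version of the same family filter over a fabricated model list, just to make the set-membership check concrete:

from enum import Enum


class ModelFamily(Enum):
    llama3_1 = "llama3_1"
    llama3_2 = "llama3_2"
    llama3_3 = "llama3_3"


# fabricated registry entries, purely to exercise the filter
ALL_MODELS = [
    ("Llama3.1-8B-Instruct", ModelFamily.llama3_1),
    ("Llama3.2-3B-Instruct", ModelFamily.llama3_2),
    ("Llama3.3-70B-Instruct", ModelFamily.llama3_3),
]

supported = [
    name for name, family in ALL_MODELS
    if family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
]
print(supported)  # the llama3_3 entry no longer passes
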
@@ -16,7 +16,7 @@ providers:
   - provider_id: meta-reference-inference
     provider_type: inline::meta-reference
     config:
-      model: ${env.INFERENCE_MODEL}
+      # model: ${env.INFERENCE_MODEL}
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
   memory:

@@ -72,11 +72,11 @@ metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
-models:
-- metadata: {}
-  model_id: ${env.INFERENCE_MODEL}
-  provider_id: meta-reference-inference
-  provider_model_id: null
+models: []
+# - metadata: {}
+#   model_id: ${env.INFERENCE_MODEL}
+#   provider_id: meta-reference-inference
+#   provider_model_id: null
 shields: []
 memory_banks: []
 datasets: []

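Note: with the model commented out of the provider config and the static models: list emptied, run.yaml no longer declares any model; the intent in this change is that models are registered through the models API at runtime instead. The config still leans on ${env.NAME} and ${env.NAME:default} substitution; the snippet below is a rough illustration of that substitution idea only, not the stack's actual config loader:

import os
import re


def substitute_env(value: str) -> str:
    # replace ${env.NAME} / ${env.NAME:default} with environment values
    def repl(match: re.Match) -> str:
        name, _, default = match.group(1).partition(":")
        return os.environ.get(name, default)

    return re.sub(r"\$\{env\.([^}]+)\}", repl, value)


os.environ["INFERENCE_MODEL"] = "Llama3.2-3B-Instruct"
print(substitute_env("model: ${env.INFERENCE_MODEL}"))
print(substitute_env("checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}"))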