Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 17:29:01 +00:00)

Commit 35b1a6f2dc (parent de44af1501): "temp commit"

7 changed files with 54 additions and 113 deletions

@@ -74,6 +74,7 @@ class InferenceRouter(Inference):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        print("InferenceRouter init")
         self.routing_table = routing_table

     async def initialize(self) -> None:

@@ -18,7 +18,6 @@ async def get_provider_impl(
     print("get_provider_impl")

     impl = MetaReferenceInferenceImpl(config)
-    if config.model:
-        # pre-load the model if the model is in the config
-        await impl.initialize()
+    print("after MetaReferenceInferenceImpl")
     return impl

@@ -16,7 +16,9 @@ from llama_stack.providers.utils.inference import supported_inference_models


 class MetaReferenceInferenceConfig(BaseModel):
-    model: Optional[str] = None
+    model: Optional[str] = (
+        None  # this is a placeholder to indicate inference model id, not actually being used
+    )
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1

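With this change the config's `model` field is only a placeholder; the provider is now told which model to load at initialize time. A minimal sketch of that flow, assuming the import paths below and using a hypothetical model id (neither is part of the diff):

import asyncio

# import paths are assumptions about the llama-stack source layout
from llama_stack.providers.inline.inference.meta_reference.config import (
    MetaReferenceInferenceConfig,
)
from llama_stack.providers.inline.inference.meta_reference.inference import (
    MetaReferenceInferenceImpl,
)


async def main() -> None:
    # `model` is left unset: per this diff it no longer selects the checkpoint
    config = MetaReferenceInferenceConfig(max_seq_len=4096, max_batch_size=1)
    impl = MetaReferenceInferenceImpl(config)
    # the model id now arrives via initialize() (hypothetical id for illustration)
    await impl.initialize("Llama3.1-8B-Instruct")


asyncio.run(main())
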
@@ -79,7 +79,7 @@ class Llama:
         config: Union[
             MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
         ],
-        request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
+        model_id: str,
     ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
@@ -88,12 +88,7 @@ class Llama:
         This method initializes the distributed process group, sets the device to CUDA,
         and loads the pre-trained model and tokenizer.
         """
-        if config.model:
-            model = resolve_model(config.model)
-        elif request:
-            model = resolve_model(request.model)
-        else:
-            raise RuntimeError("you need to provide a model for inference")
+        model = resolve_model(model_id)

         llama_model = model.core_model_id.value

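`Llama.build` now resolves the checkpoint purely from the `model_id` argument via `resolve_model`. A small sketch of what that lookup yields, assuming the `llama_models.sku_list` import path and a hypothetical model id:

from llama_models.sku_list import resolve_model  # import path is an assumption

model = resolve_model("Llama3.1-8B-Instruct")  # hypothetical model id
if model is None:
    # mirrors the error handling used throughout this diff
    raise RuntimeError("Unknown model, run `llama model list`")
print(model.descriptor(), model.core_model_id.value)
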
@@ -43,103 +43,68 @@ class MetaReferenceInferenceImpl(
     ModelsProtocolPrivate,
 ):
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
-        print("MetaReferenceInferenceImpl init")
         self.config = config
         self.model = None
-        self.model_registry_helper = None
-        if config.model:
-            model = resolve_model(config.model)
-            if model is None:
-                raise RuntimeError(
-                    f"Unknown model: {config.model}, Run `llama model list`"
-                )
-            self.model_registry_helper = ModelRegistryHelper(
-                [
-                    build_model_alias(
-                        model.descriptor(),
-                        model.core_model_id.value,
-                    )
-                ],
-            )
-            self.model = model
-            # verify that the checkpoint actually is for this model lol
-        else:
-            print("inference model isn't pre-loaded")

-    async def _setup_model(self, model_id: str) -> Optional[Model]:
-        model = resolve_model(model_id)
-        if model is None:
-            raise RuntimeError(f"Unknown model: {model_id}, Run `llama model list`")
-        # self.model_registry_helper = ModelRegistryHelper(
-        #     [
-        #         build_model_alias(
-        #             model.descriptor(),
-        #             model.core_model_id.value,
-        #         )
-        #     ],
-        # )
-        # return await self.register_model(model)
-        return model
-
-    async def initialize(self) -> None:
-        if self.model is None:
-            raise RuntimeError("model hasn't been setup yet")
-        log.info(f"Loading model `{self.model.descriptor()}`")
+    async def initialize(self, model_id) -> None:
+        log.info(f"Loading model `{model_id}`")
         if self.config.create_distributed_process_group:
-            self.generator = LlamaModelParallelGenerator(self.config)
+            self.generator = LlamaModelParallelGenerator(self.config, model_id)
             self.generator.start()
         else:
-            self.generator = Llama.build(self.config)
-
-    async def _lazy_initialize(self, request) -> None:
-        if self.model is None:
-            raise RuntimeError("model hasn't been setup yet")
-        print(f"Lazy loading model `{self.model.descriptor()}`")
-        if self.config.create_distributed_process_group:
-            # with LlamaModelParallelGenerator(self.config, request) as resouce:
-            self.generator = LlamaModelParallelGenerator(self.config, request)
-            self.generator.start()
-        else:
-            self.generator = Llama.build(self.config, request)
+            self.generator = Llama.build(self.config, model_id)
+        self.model = model_id

     async def shutdown(self) -> None:
         if self.config.create_distributed_process_group:
             self.generator.stop()

     def check_model(self, request) -> None:
-        model = resolve_model(request.model)
-        if model is None:
+        if self.model is None:
+            raise RuntimeError(
+                "Inference model hasn't been initialized yet, please register your requested model or add your model in the resouces first"
+            )
+        inference_model = resolve_model(self.model)
+        requested_model = resolve_model(request.model)
+        if requested_model is None:
             raise RuntimeError(
                 f"Unknown model: {request.model}, Run `llama model list`"
             )
-        elif self.model and model.descriptor() != self.model.descriptor():
+        elif requested_model.descriptor() != inference_model.descriptor():
             raise RuntimeError(
-                f"Model mismatch: {request.model} != {self.model.descriptor()}"
+                f"Model mismatch: {request.model} != {inference_model.descriptor()}"
             )

     async def unregister_model(self, model_id: str) -> None:
         pass

     async def register_model(self, model: LlamaStackModel) -> LlamaStackModel:
-        if self.model_registry_helper is None:
-            llama_model = resolve_model(model.identifier)
-            if llama_model is None:
-                raise RuntimeError(
-                    f"Unknown model: {model.identifier}, Run `llama model list`"
-                )
-            self.model_registry_helper = ModelRegistryHelper(
-                [
-                    build_model_alias(
-                        llama_model.descriptor(),
-                        llama_model.core_model_id.value,
-                    )
-                ],
-            )
+        llama_model = resolve_model(model.identifier)
+        if llama_model is None:
+            raise RuntimeError(
+                f"Unknown model: {model.identifier}, Please make sure your model is in llama-models SKU list"
+            )
+        self.model_registry_helper = ModelRegistryHelper(
+            [
+                build_model_alias(
+                    llama_model.descriptor(),
+                    llama_model.core_model_id.value,
+                )
+            ],
+        )
         model = await self.model_registry_helper.register_model(model)
         print("model type", type(model))
         if model.model_type == ModelType.embedding_model:
             self._load_sentence_transformer_model(model.provider_resource_id)

+        if (
+            model.metadata
+            and "skip_initialize" in model.metadata
+            and model.metadata["skip_initialize"]
+        ):
+            return model
+        await self.initialize(model.identifier)
         return model

     async def completion(
@@ -171,10 +136,6 @@ class MetaReferenceInferenceImpl(
         return await self._nonstream_completion(request)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        if self.model is None:
-            self.model = await self._setup_model(request.model)
-            await self._lazy_initialize(request)
-
         def impl():
             stop_reason = None
@@ -224,10 +185,6 @@ class MetaReferenceInferenceImpl(
     async def _nonstream_completion(
         self, request: CompletionRequest
     ) -> CompletionResponse:
-        if self.model is None:
-            self.model = await self._setup_model(request.model)
-            await self._lazy_initialize(request)
-
         def impl():
             tokens = []
             logprobs = []
@@ -310,10 +267,6 @@ class MetaReferenceInferenceImpl(
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        if self.model is None:
-            self.model = await self._setup_model(request.model)
-            await self._lazy_initialize(request)
-
         def impl():
             tokens = []
             logprobs = []
@@ -359,10 +312,6 @@ class MetaReferenceInferenceImpl(
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        if self.model is None:
-            self.model = await self._setup_model(request.model)
-            await self._lazy_initialize(request)
-
         def impl():
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(

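The net effect of the hunks above: model loading is driven by registration rather than by the provider config, and a registration can defer loading with `skip_initialize` metadata. A condensed restatement of that flow, using only names that appear in the diff (a sketch, not a drop-in implementation):

async def register_then_load(impl, model):
    # register the alias first, exactly as the new register_model does
    model = await impl.model_registry_helper.register_model(model)
    # metadata can defer the (expensive) checkpoint load
    if model.metadata and model.metadata.get("skip_initialize"):
        return model
    # otherwise initialize() loads the model named by the registration
    await impl.initialize(model.identifier)
    return model
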
@@ -7,7 +7,7 @@
 import os
 from copy import deepcopy
 from functools import partial
-from typing import Any, Generator, Optional, Union
+from typing import Any, Generator

 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -36,9 +36,9 @@ class ModelRunner:


 def init_model_cb(
     config: MetaReferenceInferenceConfig,
-    request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
+    model_id: str,
 ):
-    llama = Llama.build(config, request)
+    llama = Llama.build(config, model_id)
     return ModelRunner(llama)
@@ -56,17 +56,12 @@ class LlamaModelParallelGenerator:
     def __init__(
         self,
         config: MetaReferenceInferenceConfig,
-        request: Optional[Union[CompletionRequest, ChatCompletionRequest]] = None,
+        model_id: str,
     ):
         print("LlamaModelParallelGenerator init")
         self.config = config
-        self.request = request
-        if config.model:
-            self.model = resolve_model(config.model)
-        elif request:
-            self.model = resolve_model(request.model)
-        else:
-            raise RuntimeError("you need to provide a model for inference")
+        self.model_id = model_id
+        self.model = resolve_model(model_id)

         # this is a hack because Agent's loop uses this to tokenize and check if input is too long
         # while the tool-use loop is going
@@ -89,7 +84,7 @@ class LlamaModelParallelGenerator:

         self.group = ModelParallelProcessGroup(
             model_parallel_size,
-            init_model_cb=partial(init_model_cb, self.config, self.request),
+            init_model_cb=partial(init_model_cb, self.config, self.model_id),
         )
         self.group.start()
         return self

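`LlamaModelParallelGenerator` now hands the worker callback a plain `model_id` instead of a whole request object, bound up front with `functools.partial`. A standalone sketch of that pattern with illustrative stand-in names (not the diff's actual classes):

from functools import partial


def init_model_cb(config: dict, model_id: str) -> str:
    # stand-in for Llama.build(config, model_id) wrapped in a ModelRunner
    return f"runner(model={model_id}, max_seq_len={config['max_seq_len']})"


# the process group only needs a zero-argument callable; config and model_id
# are bound ahead of time, like partial(init_model_cb, self.config, self.model_id)
make_runner = partial(init_model_cb, {"max_seq_len": 4096}, "Llama3.1-8B-Instruct")
print(make_runner())
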
@@ -16,7 +16,7 @@ providers:
   - provider_id: meta-reference-inference
     provider_type: inline::meta-reference
     config:
-      # model: ${env.INFERENCE_MODEL}
+      model: ${env.INFERENCE_MODEL}  # please make sure your inference model here is added as resource
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
   memory:
@@ -73,10 +73,10 @@ metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
 models: []
-# - metadata: {}
-#   model_id: ${env.INFERENCE_MODEL}
-#   provider_id: meta-reference-inference
-#   provider_model_id: null
+- metadata: {}
+  model_id: ${env.INFERENCE_MODEL}
+  provider_id: meta-reference-inference
+  provider_model_id: null
 shields: []
 memory_banks: []
 datasets: []