llama-models should have extremely minimal cruft. Its sole purpose should be didactic -- show the simplest implementation of the llama models and document the prompt formats, etc.

This PR is the complement to https://github.com/meta-llama/llama-models/pull/279

## Test Plan

Ensure all `llama` CLI `model` sub-commands work:

```bash
llama model list
llama model download --model-id ...
llama model prompt-format -m ...
```

Ran tests:

```bash
cd tests/client-sdk
LLAMA_STACK_CONFIG=fireworks pytest -s -v inference/
LLAMA_STACK_CONFIG=fireworks pytest -s -v vector_io/
LLAMA_STACK_CONFIG=fireworks pytest -s -v agents/
```

Create a fresh venv with `uv venv && source .venv/bin/activate`, then run `llama stack build --template fireworks --image-type venv` followed by `llama stack run together --image-type venv` -- the server runs.

Also checked that the OpenAPI generator can run and there is no change in the generated files as a result:

```bash
cd docs/openapi_generator
sh run_openapi_generator.sh
```
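For convenience, the venv setup and server steps from the test plan above, collected into a single block (template and distribution names exactly as given in the description):

```bash
# Fresh virtual environment
uv venv && source .venv/bin/activate

# Build the fireworks template as a venv image, then run the server
llama stack build --template fireworks --image-type venv
llama stack run together --image-type venv
```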
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from copy import deepcopy
from functools import partial
from typing import Any, Generator

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.utils.inference.prompt_adapter import (
    ChatCompletionRequestWithRawContent,
    CompletionRequestWithRawContent,
)

from .config import MetaReferenceInferenceConfig
from .generation import Llama, model_checkpoint_dir
from .parallel_utils import ModelParallelProcessGroup


class ModelRunner:
    """Callable created by `init_model_cb` inside the model-parallel workers; it
    dispatches an incoming request to the underlying Llama generator."""

    def __init__(self, llama):
        self.llama = llama

    # `req` is the same task object that is sent to `ModelParallelProcessGroup.run_inference()`
    def __call__(self, req: Any):
        if isinstance(req, ChatCompletionRequestWithRawContent):
            return self.llama.chat_completion(req)
        elif isinstance(req, CompletionRequestWithRawContent):
            return self.llama.completion(req)
        else:
            raise ValueError(f"Unexpected task type {type(req)}")


def init_model_cb(
    config: MetaReferenceInferenceConfig,
    model_id: str,
    llama_model: Model,
):
    # Builds the model inside the worker process and wraps it in a ModelRunner.
    llama = Llama.build(config, model_id, llama_model)
    return ModelRunner(llama)


class LlamaModelParallelGenerator:
    """
    This abstraction exists so that:
     - we can run model parallel code without needing to run the CLIs via torchrun
     - we can use model parallel code within a notebook context.

    A context manager is used to ensure that the model parallel process is started and stopped
    correctly. This does make the ergonomics a little awkward, because it isn't immediately
    clear at the call site why we need to use a context manager.
    """

    def __init__(
        self,
        config: MetaReferenceInferenceConfig,
        model_id: str,
        llama_model: Model,
    ):
        self.config = config
        self.model_id = model_id
        self.llama_model = llama_model

        # This is a hack: the Agent's loop uses this formatter to tokenize the input and check
        # whether it has become too long while the tool-use loop is running.
        resolved_model = resolve_model(model_id)
        if resolved_model is None:
            # if the model is not a native llama model, get the default checkpoint_dir based on the model id
            checkpoint_dir = model_checkpoint_dir(model_id)
        else:
            # if the model is a native llama model, get the default checkpoint_dir based on its core_model_id value
            checkpoint_dir = model_checkpoint_dir(resolved_model.descriptor())
        tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
        self.formatter = ChatFormat(Tokenizer(tokenizer_path))

    def start(self):
        self.__enter__()

    def stop(self):
        self.__exit__(None, None, None)

    def __enter__(self):
        # one worker process per checkpoint shard
        model_parallel_size = self.llama_model.pth_file_count

        self.group = ModelParallelProcessGroup(
            model_parallel_size,
            init_model_cb=partial(init_model_cb, self.config, self.model_id, self.llama_model),
        )
        self.group.start()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.group.stop()

    def completion(
        self,
        request: CompletionRequestWithRawContent,
    ) -> Generator:
        req_obj = deepcopy(request)
        gen = self.group.run_inference(req_obj)
        yield from gen

    def chat_completion(
        self,
        request: ChatCompletionRequestWithRawContent,
    ) -> Generator:
        req_obj = deepcopy(request)
        gen = self.group.run_inference(req_obj)
        yield from gen
```
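For reference, a minimal usage sketch of the generator's lifecycle; `config`, `model_id`, `llama_model`, and the request objects below are placeholders, and how the request types are constructed is assumed rather than shown in this file:

```python
# Hypothetical usage sketch -- placeholder names, not taken from this file.
# Worker processes are spawned in __enter__ and torn down in __exit__, so all
# inference calls must happen while the context (or start()/stop() pair) is active.
with LlamaModelParallelGenerator(config, model_id, llama_model) as generator:
    # chat_completion() forwards the request to the worker group and yields
    # results as they are produced.
    for result in generator.chat_completion(chat_request):
        print(result)

# The explicit start()/stop() methods provide the same lifecycle outside a `with` block:
generator = LlamaModelParallelGenerator(config, model_id, llama_model)
generator.start()
try:
    for result in generator.completion(completion_request):
        print(result)
finally:
    generator.stop()
```

Note that `completion()` and `chat_completion()` deepcopy the request before handing it to the worker group, so the caller's request object is never mutated.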