Merge branch 'main' into vllm

This commit is contained in:
Fred Reiss 2025-01-08 15:47:58 -08:00 committed by GitHub
commit 73fede90a6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
175 changed files with 7948 additions and 876 deletions

View file

@ -6,11 +6,10 @@
from typing import Any, Dict, Optional
from llama_models.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, field_validator
from llama_stack.apis.inference import QuantizationConfig
from llama_stack.providers.utils.inference import supported_inference_models

View file

@ -32,11 +32,16 @@ from llama_models.llama3.reference_impl.multimodal.model import (
CrossAttentionTransformer,
)
from llama_models.sku_list import resolve_model
from pydantic import BaseModel
from llama_stack.apis.inference import * # noqa: F403
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from pydantic import BaseModel
from llama_stack.apis.inference import (
Fp8QuantizationConfig,
Int4QuantizationConfig,
ResponseFormat,
ResponseFormatType,
)
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
@ -44,12 +49,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
CompletionRequestWithRawContent,
)
from .config import (
Fp8QuantizationConfig,
Int4QuantizationConfig,
MetaReferenceInferenceConfig,
MetaReferenceQuantizedInferenceConfig,
)
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
log = logging.getLogger(__name__)

View file

@ -14,7 +14,10 @@ from llama_models.llama3.api.datatypes import Model
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
from .config import MetaReferenceInferenceConfig
from .generation import Llama, model_checkpoint_dir
@ -27,9 +30,9 @@ class ModelRunner:
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, req: Any):
if isinstance(req, ChatCompletionRequest):
if isinstance(req, ChatCompletionRequestWithRawContent):
return self.llama.chat_completion(req)
elif isinstance(req, CompletionRequest):
elif isinstance(req, CompletionRequestWithRawContent):
return self.llama.completion(req)
else:
raise ValueError(f"Unexpected task type {type(req)}")
@ -100,7 +103,7 @@ class LlamaModelParallelGenerator:
def completion(
self,
request: CompletionRequest,
request: CompletionRequestWithRawContent,
) -> Generator:
req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)
@ -108,7 +111,7 @@ class LlamaModelParallelGenerator:
def chat_completion(
self,
request: ChatCompletionRequest,
request: ChatCompletionRequestWithRawContent,
) -> Generator:
req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)

View file

@ -34,7 +34,10 @@ from pydantic import BaseModel, Field
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from typing_extensions import Annotated
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
from .generation import TokenResult
@ -79,7 +82,7 @@ class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = (
ProcessingMessageName.task_request
)
task: Union[CompletionRequest, ChatCompletionRequest]
task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
class TaskResponse(BaseModel):
@ -264,9 +267,6 @@ def launch_dist_group(
init_model_cb: Callable,
**kwargs,
) -> None:
id = uuid.uuid4().hex
dist_url = f"file:///tmp/llama3_{id}_{time.time()}"
with tempfile.TemporaryDirectory() as tmpdir:
# TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
launch_config = LaunchConfig(
@ -315,7 +315,7 @@ def start_model_parallel_process(
# wait until the model is loaded; rank 0 will send a message to indicate it's ready
request_socket.send(encode_msg(ReadyRequest()))
response = request_socket.recv()
_response = request_socket.recv()
log.info("Loaded model...")
return request_socket, process
@ -349,7 +349,10 @@ class ModelParallelProcessGroup:
self.started = False
def run_inference(
self, req: Union[CompletionRequest, ChatCompletionRequest]
self,
req: Union[
CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent
],
) -> Generator:
assert not self.running, "inference already running"

View file

@ -7,10 +7,10 @@
import logging
import os
import uuid
from typing import AsyncGenerator, Optional
from typing import AsyncGenerator, List, Optional
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
@ -18,9 +18,26 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,