From a08fd8f3317f419a4962100af9f3da0f44064baa Mon Sep 17 00:00:00 2001
From: Russell Bryant
Date: Sat, 28 Sep 2024 18:46:35 +0000
Subject: [PATCH] Add boilerplate for vllm inference provider

Signed-off-by: Russell Bryant
---
 .../providers/adapters/inference/__init__.py |  5 ---
 llama_stack/providers/impls/vllm/__init__.py | 11 +++++++
 llama_stack/providers/impls/vllm/config.py   |  5 +++
 llama_stack/providers/impls/vllm/vllm.py     | 33 +++++++++++++++++++
 4 files changed, 49 insertions(+), 5 deletions(-)
 create mode 100644 llama_stack/providers/impls/vllm/__init__.py
 create mode 100644 llama_stack/providers/impls/vllm/config.py
 create mode 100644 llama_stack/providers/impls/vllm/vllm.py

diff --git a/llama_stack/providers/adapters/inference/__init__.py b/llama_stack/providers/adapters/inference/__init__.py
index 756f351d8..e69de29bb 100644
--- a/llama_stack/providers/adapters/inference/__init__.py
+++ b/llama_stack/providers/adapters/inference/__init__.py
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
diff --git a/llama_stack/providers/impls/vllm/__init__.py b/llama_stack/providers/impls/vllm/__init__.py
new file mode 100644
index 000000000..3d5a81ad9
--- /dev/null
+++ b/llama_stack/providers/impls/vllm/__init__.py
@@ -0,0 +1,11 @@
+from typing import Any
+
+from .config import VLLMConfig
+
+
+async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
+    from .vllm import VLLMInferenceImpl
+
+    impl = VLLMInferenceImpl(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/impls/vllm/config.py b/llama_stack/providers/impls/vllm/config.py
new file mode 100644
index 000000000..fe79767a8
--- /dev/null
+++ b/llama_stack/providers/impls/vllm/config.py
@@ -0,0 +1,5 @@
+from pydantic import BaseModel
+
+
+class VLLMConfig(BaseModel):
+    pass
diff --git a/llama_stack/providers/impls/vllm/vllm.py b/llama_stack/providers/impls/vllm/vllm.py
new file mode 100644
index 000000000..1f6c83441
--- /dev/null
+++ b/llama_stack/providers/impls/vllm/vllm.py
@@ -0,0 +1,33 @@
+import logging
+from typing import Any
+
+from llama_stack.apis.inference.inference import CompletionResponse, CompletionResponseStreamChunk, LogProbConfig, ChatCompletionResponse, ChatCompletionResponseStreamChunk, EmbeddingsResponse
+from llama_stack.apis.inference import Inference
+
+from .config import VLLMConfig
+
+from llama_models.llama3.api.datatypes import InterleavedTextMedia, Message, ToolChoice, ToolDefinition, ToolPromptFormat
+
+
+log = logging.getLogger(__name__)
+
+
+class VLLMInferenceImpl(Inference):
+    """Inference implementation for vLLM."""
+    def __init__(self, config: VLLMConfig):
+        self.config = config
+
+    async def initialize(self):
+        log.info("Initializing vLLM inference adapter")
+        pass
+
+    async def completion(self, model: str, content: InterleavedTextMedia, sampling_params: Any | None = ..., stream: bool | None = False, logprobs: LogProbConfig | None = None) -> CompletionResponse | CompletionResponseStreamChunk:
+        log.info("vLLM completion")
+        return None
+
+    async def chat_completion(self, model: str, messages: list[Message], sampling_params: Any | None = ..., tools: list[ToolDefinition] | None = ..., tool_choice: ToolChoice | None = ..., tool_prompt_format: ToolPromptFormat | None = ..., stream: bool | None = False, logprobs: LogProbConfig | None = None) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
+        log.info("vLLM chat completion")
+        return None
+
+    async def embeddings(self, model: str, contents: list[InterleavedTextMedia]) -> EmbeddingsResponse:
+        log.info("vLLM embeddings")