Add boilerplate for vllm inference provider

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Russell Bryant 2024-09-28 18:46:35 +00:00
parent 8d41e6caa9
commit a08fd8f331
4 changed files with 49 additions and 5 deletions

@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -0,0 +1,11 @@
from typing import Any

from .config import VLLMConfig


async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
    from .vllm import VLLMInferenceImpl

    impl = VLLMInferenceImpl(config)
    await impl.initialize()
    return impl
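
For orientation, a caller inside the same package would use this entry point roughly as follows. This is a minimal sketch: the empty `_deps` mapping is a stand-in for whatever dependencies the llama-stack resolver normally injects, and `build_provider` is a hypothetical helper, not code from this commit.

import asyncio

from . import get_provider_impl
from .config import VLLMConfig


async def build_provider():
    # Passing an empty deps mapping is an assumption for illustration;
    # in practice the stack resolver supplies real dependencies here.
    impl = await get_provider_impl(VLLMConfig(), {})
    return impl  # -> VLLMInferenceImpl, already initialized


# e.g. impl = asyncio.run(build_provider())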

@@ -0,0 +1,5 @@
from pydantic import BaseModel


class VLLMConfig(BaseModel):
    pass
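
VLLMConfig is deliberately empty at this stage. Once real engine wiring lands it would presumably grow settings along these lines; the class name, fields, and defaults below are illustrative guesses, not part of this commit.

from pydantic import BaseModel, Field


class VLLMConfigSketch(BaseModel):
    # Hypothetical fields a vLLM provider config might carry later on.
    model: str = Field(default="Llama3.1-8B-Instruct", description="Model to load")
    tensor_parallel_size: int = Field(default=1, description="Number of GPUs to shard across")
    max_tokens: int = Field(default=4096, description="Default cap on generated tokens")
    enforce_eager: bool = Field(default=False, description="Disable CUDA graph capture")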

@@ -0,0 +1,33 @@
import logging
from typing import Any

from llama_models.llama3.api.datatypes import (
    InterleavedTextMedia,
    Message,
    ToolChoice,
    ToolDefinition,
    ToolPromptFormat,
)

from llama_stack.apis.inference import Inference
from llama_stack.apis.inference.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    LogProbConfig,
)

from .config import VLLMConfig

log = logging.getLogger(__name__)


class VLLMInferenceImpl(Inference):
    """Inference implementation for vLLM."""

    def __init__(self, config: VLLMConfig):
        self.config = config

    async def initialize(self):
        log.info("Initializing vLLM inference adapter")
        pass

    async def completion(
        self,
        model: str,
        content: InterleavedTextMedia,
        sampling_params: Any | None = ...,
        stream: bool | None = False,
        logprobs: LogProbConfig | None = None,
    ) -> CompletionResponse | CompletionResponseStreamChunk:
        log.info("vLLM completion")
        return None

    async def chat_completion(
        self,
        model: str,
        messages: list[Message],
        sampling_params: Any | None = ...,
        tools: list[ToolDefinition] | None = ...,
        tool_choice: ToolChoice | None = ...,
        tool_prompt_format: ToolPromptFormat | None = ...,
        stream: bool | None = False,
        logprobs: LogProbConfig | None = None,
    ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
        log.info("vLLM chat completion")
        return None

    async def embeddings(
        self, model: str, contents: list[InterleavedTextMedia]
    ) -> EmbeddingsResponse:
        log.info("vLLM embeddings")
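
All three methods are deliberate no-ops in this boilerplate commit. As a rough sketch of where the implementation is headed, the stubs could eventually drive an in-process vLLM engine along these lines. Nothing below is part of this commit: the model id and sampling values are placeholders, and mapping the raw text back into llama-stack response types is left out.

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def generate_with_vllm(prompt: str) -> str:
    # Illustrative only: spin up an in-process vLLM engine; the model id
    # here is a placeholder, not something this commit configures.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="meta-llama/Llama-3.1-8B-Instruct")
    )
    params = SamplingParams(temperature=0.7, max_tokens=256)
    text = ""
    # AsyncLLMEngine.generate yields incremental RequestOutput objects;
    # the last one seen holds the full generation so far.
    async for request_output in engine.generate(prompt, params, request_id="req-0"):
        text = request_output.outputs[0].text
    return text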