From 7bbce6394a073b7f730ce008020b07de0b5e6ac0 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Thu, 10 Oct 2024 20:58:52 -0400 Subject: [PATCH] Working --- .../providers/adapters/inference/vllm/config.py | 1 - .../providers/adapters/inference/vllm/vllm.py | 12 ++++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/llama_stack/providers/adapters/inference/vllm/config.py b/llama_stack/providers/adapters/inference/vllm/config.py index cec57c814..d1cfa6dde 100644 --- a/llama_stack/providers/adapters/inference/vllm/config.py +++ b/llama_stack/providers/adapters/inference/vllm/config.py @@ -10,7 +10,6 @@ from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field -# TODO: Any other engine configs @json_schema_type class VLLMImplConfig(BaseModel): url: Optional[str] = Field( diff --git a/llama_stack/providers/adapters/inference/vllm/vllm.py b/llama_stack/providers/adapters/inference/vllm/vllm.py index 1e2799a51..5ddcefe89 100644 --- a/llama_stack/providers/adapters/inference/vllm/vllm.py +++ b/llama_stack/providers/adapters/inference/vllm/vllm.py @@ -29,7 +29,8 @@ from .config import VLLMImplConfig # Reference: https://docs.vllm.ai/en/latest/models/supported_models.html VLLM_SUPPORTED_MODELS = { - "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3-70B-Instruct", + "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct", "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct", } @@ -48,7 +49,14 @@ class VLLMInferenceAdapter(ModelRegistryHelper, Inference): async def shutdown(self) -> None: pass - def completion(self, request: CompletionRequest) -> AsyncGenerator: + def completion( + self, + model: str, + content: InterleavedTextMedia, + sampling_params: Optional[SamplingParams] = SamplingParams(), + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: raise NotImplementedError() def chat_completion(