Address feedback

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Yuan Tang 2024-10-11 16:08:45 -04:00
parent 660983b72d
commit c085886cdb


@@ -8,11 +8,11 @@ from typing import AsyncGenerator
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import all_registered_models
 from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
@@ -25,36 +25,54 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import VLLMImplConfig
 
+VLLM_SUPPORTED_MODELS = {
+    "Llama3.1-8B": "meta-llama/Llama-3.1-8B",
+    "Llama3.1-70B": "meta-llama/Llama-3.1-70B",
+    "Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
+    "Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
+    "Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
+    "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
+    "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
+    "Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
+    "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
+    "Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
+    "Llama3.2-1B": "meta-llama/Llama-3.2-1B",
+    "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
+    "Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision",
+    "Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision",
+    "Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
+    "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
+    "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
+    "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
+    "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
+    "Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4",
+    "Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
+    "Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
+    "Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8",
+    "Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M",
+    "Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B",
+}
+
 
-class VLLMInferenceAdapter(Inference):
-    model_id: str
+class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMImplConfig) -> None:
-        self.huggingface_repo_to_llama_model_id = {
-            model.huggingface_repo: model.descriptor()
-            for model in all_registered_models()
-            if model.huggingface_repo
-        }
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
+        self.client = None
 
     async def initialize(self) -> None:
-        return
+        self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
+
+    async def register_model(self, model: ModelDef) -> None:
+        raise ValueError("Model registration is not supported for vLLM models")
 
     async def shutdown(self) -> None:
         pass
 
     async def list_models(self) -> List[ModelDef]:
-        repo = self.model_id
-        identifier = self.huggingface_repo_to_llama_model_id[repo]
         return [
-            ModelDef(
-                identifier=identifier,
-                llama_model=identifier,
-                metadata={
-                    "huggingface_repo": repo,
-                },
-            )
+            ModelDef(identifier=model.id, llama_model=model.id)
+            for model in self.client.models.list()
         ]
 
     def completion(
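Note: with this hunk, list_models now asks the running vLLM server what it is actually serving instead of echoing a single statically configured model_id. A minimal standalone sketch of that flow, outside the adapter — the endpoint URL and API token below are placeholders, not values from this commit:

from openai import OpenAI

# vLLM exposes an OpenAI-compatible /v1/models endpoint when started as a
# server (e.g. `python -m vllm.entrypoints.openai.api_server --model ...`).
client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-token")

# Each entry's `id` is the model name the server was launched with; the
# adapter wraps each one as ModelDef(identifier=model.id, llama_model=model.id).
for model in client.models.list():
    print(model.id)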
@@ -88,12 +106,10 @@ class VLLMInferenceAdapter(Inference):
             stream=stream,
             logprobs=logprobs,
         )
-        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
         if stream:
-            return self._stream_chat_completion(request, client)
+            return self._stream_chat_completion(request, self.client)
         else:
-            return self._nonstream_chat_completion(request, client)
+            return self._nonstream_chat_completion(request, self.client)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: OpenAI
@@ -122,7 +138,7 @@ class VLLMInferenceAdapter(Inference):
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
-            "model": request.model,
+            "model": VLLM_SUPPORTED_MODELS[request.model],
             "prompt": chat_completion_request_to_prompt(request, self.formatter),
             "stream": request.stream,
             **get_sampling_options(request),
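Taken together, the adapter now holds one OpenAI client for its whole lifetime (created in initialize, reused by every request) and translates Llama Stack model descriptors into the HuggingFace repo names vLLM serves via VLLM_SUPPORTED_MODELS. A rough usage sketch — the module paths are an assumption about this PR's package layout (the diff does not show the file path), and url/api_token mirror the config fields the adapter reads above:

import asyncio

# Assumed import paths; adjust to the actual provider package layout.
from llama_stack.providers.adapters.inference.vllm.config import VLLMImplConfig
from llama_stack.providers.adapters.inference.vllm.vllm import VLLMInferenceAdapter


async def main() -> None:
    # url/api_token are the fields initialize() passes to the OpenAI client.
    config = VLLMImplConfig(url="http://localhost:8000/v1", api_token="fake-token")
    adapter = VLLMInferenceAdapter(config)
    await adapter.initialize()  # creates the shared OpenAI client once

    # Reflects whatever the vLLM server is actually serving.
    for model_def in await adapter.list_models():
        print(model_def.identifier, model_def.llama_model)

    await adapter.shutdown()


asyncio.run(main())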