Fix import and model mapping

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Yuan Tang 2024-10-03 17:57:21 -04:00
parent 925e1afb5b
commit 765f2c86af

@@ -15,12 +15,16 @@ from llama_models.sku_list import resolve_model
 from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
+from llama_stack.providers.utils.inference.augment_messages import augment_messages_for_tools
 
 from .config import VLLMImplConfig
 
-# TODO
-VLLM_SUPPORTED_MODELS = {}
+# Reference: https://docs.vllm.ai/en/latest/models/supported_models.html
+VLLM_SUPPORTED_MODELS = {
+    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
+    "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
+}
 
 
 class VLLMInferenceAdapter(Inference):
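The new VLLM_SUPPORTED_MODELS dict maps Llama Stack model names to the Hugging Face repo ids that vLLM serves (see the referenced supported-models page). Below is a minimal sketch of how a lookup helper along the lines of the adapter's resolve_vllm_model could consume it; the helper's actual body is not part of this diff, so the standalone function is an assumption.

```python
# Sketch only: VLLM_SUPPORTED_MODELS as introduced in this commit, plus a
# hypothetical lookup helper (the adapter's real resolve_vllm_model is not
# shown in this diff).
VLLM_SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
    "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
}


def resolve_vllm_model_name(model_name: str) -> str:
    """Map a Llama Stack model name to the Hugging Face id served by vLLM."""
    if model_name not in VLLM_SUPPORTED_MODELS:
        raise ValueError(
            f"Unsupported model {model_name!r}; expected one of "
            f"{sorted(VLLM_SUPPORTED_MODELS)}"
        )
    return VLLM_SUPPORTED_MODELS[model_name]


assert resolve_vllm_model_name("Llama3.1-8B-Instruct") == "meta-llama/Meta-Llama-3.1-8B-Instruct"
```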
@@ -70,7 +74,10 @@ class VLLMInferenceAdapter(Inference):
 
     def get_vllm_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
-        # TODO
+        if request.sampling_params is not None:
+            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
+                if getattr(request.sampling_params, attr):
+                    options[attr] = getattr(request.sampling_params, attr)
         return options
 
     async def chat_completion(
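get_vllm_chat_options now forwards only the sampling attributes that are actually set on the request. A self-contained sketch of the same logic, using a stand-in dataclass rather than the real llama_stack SamplingParams type:

```python
# Standalone sketch of the option-mapping logic above; SamplingParams here is a
# stand-in dataclass, not the llama_stack.apis.inference type.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SamplingParams:
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_tokens: Optional[int] = None


def get_vllm_chat_options(sampling_params: Optional[SamplingParams]) -> dict:
    # Copy only the sampling attributes that are set (truthy), mirroring the
    # getattr checks in the diff.
    options = {}
    if sampling_params is not None:
        for attr in ("temperature", "top_p", "top_k", "max_tokens"):
            value = getattr(sampling_params, attr)
            if value:
                options[attr] = value
    return options


print(get_vllm_chat_options(SamplingParams(temperature=0.7, max_tokens=256)))
# {'temperature': 0.7, 'max_tokens': 256}
```

Note that the truthiness check skips values of 0 (e.g. temperature=0.0), which matches the behavior in the diff.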
@@ -99,7 +106,7 @@ class VLLMInferenceAdapter(Inference):
         # accumulate sampling params and other options to pass to vLLM
         options = self.get_vllm_chat_options(request)
         vllm_model = self.resolve_vllm_model(request.model)
-        messages = prepare_messages(request)
+        messages = augment_messages_for_tools(request)
         model_input = self.formatter.encode_dialog_prompt(messages)
         input_tokens = len(model_input.tokens)
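Downstream, the resolved model id and the accumulated options end up in a request against the vLLM server's OpenAI-compatible API. That call site is outside this hunk, so the following is only a sketch under that assumption; the base_url and api_key values are placeholders.

```python
# Hedged sketch: sending the resolved model id and option dict to a vLLM server
# through its OpenAI-compatible API (the adapter's real request path is not
# shown in this commit).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

options = {"temperature": 0.7, "max_tokens": 256}
response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
    **options,
)
print(response.choices[0].message.content)
```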