Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-29 07:14:20 +00:00
Fix import and model mapping

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>

Commit 765f2c86af (parent 925e1afb5b): 1 changed file with 12 additions and 5 deletions. All three hunks below touch the vLLM inference adapter.
@@ -15,12 +15,16 @@ from llama_models.sku_list import resolve_model
 from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
+from llama_stack.providers.utils.inference.augment_messages import augment_messages_for_tools
 
 from .config import VLLMImplConfig
 
-# TODO
-VLLM_SUPPORTED_MODELS = {}
+# Reference: https://docs.vllm.ai/en/latest/models/supported_models.html
+VLLM_SUPPORTED_MODELS = {
+    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
+    "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
+}
 
 
 class VLLMInferenceAdapter(Inference):
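The populated mapping translates Llama Stack model names into the Hugging Face model ids that a vLLM server expects. As a minimal, self-contained sketch of how that lookup might be used (the commit's actual `resolve_vllm_model` body is not shown in these hunks, so the helper below is an assumption):

```python
# Sketch only: how the mapping could resolve a Llama Stack model name to
# the Hugging Face id a vLLM server serves. The helper name mirrors the
# resolve_vllm_model call seen in the last hunk, but this body is assumed.
VLLM_SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
    "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
}


def resolve_vllm_model(model_name: str) -> str:
    hf_model_id = VLLM_SUPPORTED_MODELS.get(model_name)
    if hf_model_id is None:
        raise ValueError(f"model {model_name!r} is not supported by the vLLM adapter")
    return hf_model_id


assert resolve_vllm_model("Llama3.1-8B-Instruct") == "meta-llama/Meta-Llama-3.1-8B-Instruct"
```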
@@ -70,7 +74,10 @@ class VLLMInferenceAdapter(Inference):
 
     def get_vllm_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
-        # TODO
+        if request.sampling_params is not None:
+            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
+                if getattr(request.sampling_params, attr):
+                    options[attr] = getattr(request.sampling_params, attr)
         return options
 
     async def chat_completion(
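`get_vllm_chat_options` now copies whichever sampling attributes are set on the request into a plain dict to pass along to vLLM. Below is an illustrative, runnable sketch of that behavior; the `SamplingParams` stand-in is an assumption for demonstration, not the real llama-stack type:

```python
# Illustrative stand-in so the extraction loop from the diff can run on
# its own. Field names mirror the attributes the loop reads; the real
# llama-stack SamplingParams type may differ.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SamplingParams:
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_tokens: Optional[int] = None


def get_vllm_chat_options(sampling_params: Optional[SamplingParams]) -> dict:
    # Copy only the attributes that are set, as in the diff above.
    options = {}
    if sampling_params is not None:
        for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
            if getattr(sampling_params, attr):
                options[attr] = getattr(sampling_params, attr)
    return options


print(get_vllm_chat_options(SamplingParams(temperature=0.7, max_tokens=256)))
# {'temperature': 0.7, 'max_tokens': 256}  (set iteration, so key order may vary)
```

One consequence of the truthiness check `if getattr(...)` is that values of 0 or 0.0 are skipped, so for example a temperature of 0.0 would not be forwarded to vLLM.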
@@ -99,7 +106,7 @@
         # accumulate sampling params and other options to pass to vLLM
         options = self.get_vllm_chat_options(request)
         vllm_model = self.resolve_vllm_model(request.model)
-        messages = prepare_messages(request)
+        messages = augment_messages_for_tools(request)
         model_input = self.formatter.encode_dialog_prompt(messages)
 
         input_tokens = len(model_input.tokens)
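With the options, resolved model id, and formatter-encoded prompt in hand, the adapter can send the request to the vLLM server. The sketch below is a plausible continuation, not code from this commit, assuming the adapter talks to vLLM's OpenAI-compatible endpoint through the `OpenAI` client imported in the first hunk; the base URL, API key, and prompt string are placeholders:

```python
# Sketch under stated assumptions: call a vLLM server through its
# OpenAI-compatible completions endpoint with the accumulated options.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # from VLLM_SUPPORTED_MODELS
    prompt="<|begin_of_text|>...",  # placeholder for the formatter's encoded dialog prompt
    temperature=0.7,  # forwarded from get_vllm_chat_options
    max_tokens=256,
)
print(response.choices[0].text)
```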