Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
Fix import and model mapping

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>

parent 925e1afb5b
commit 765f2c86af

1 changed file with 12 additions and 5 deletions
@@ -15,12 +15,16 @@ from llama_models.sku_list import resolve_model
 
 from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
+from llama_stack.providers.utils.inference.augment_messages import augment_messages_for_tools
 
 from .config import VLLMImplConfig
 
-# TODO
-VLLM_SUPPORTED_MODELS = {}
+# Reference: https://docs.vllm.ai/en/latest/models/supported_models.html
+VLLM_SUPPORTED_MODELS = {
+    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
+    "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
+}
 
 
 class VLLMInferenceAdapter(Inference):
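The new VLLM_SUPPORTED_MODELS dict maps Llama Stack model names to the Hugging Face repo ids that vLLM serves. Below is a minimal sketch of how such a mapping could be resolved; the resolve_vllm_model helper referenced later in this diff is not shown in the commit, so its body here is an assumption for illustration, not the adapter's actual code.

# Hedged sketch: resolving a Llama Stack model name to the Hugging Face repo id
# that vLLM expects. The real resolve_vllm_model body is not part of this diff.
VLLM_SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
    "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
}


def resolve_vllm_model(model_name: str) -> str:
    """Return the vLLM/Hugging Face model id for a Llama Stack model name."""
    if model_name not in VLLM_SUPPORTED_MODELS:
        raise ValueError(
            f"Unknown model {model_name!r}; supported: {sorted(VLLM_SUPPORTED_MODELS)}"
        )
    return VLLM_SUPPORTED_MODELS[model_name]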
@@ -70,7 +74,10 @@ class VLLMInferenceAdapter(Inference):
 
     def get_vllm_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
-        # TODO
+        if request.sampling_params is not None:
+            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
+                if getattr(request.sampling_params, attr):
+                    options[attr] = getattr(request.sampling_params, attr)
         return options
 
     async def chat_completion(
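get_vllm_chat_options copies any set sampling parameters from the request into a plain dict of keyword arguments for the vLLM call. A standalone sketch of that logic is below; the SamplingParams stand-in and the function name build_vllm_chat_options are assumptions used to keep the example self-contained, and note that the truthiness check in the diff would also skip falsy values such as temperature=0.0.

# Hedged sketch mirroring get_vllm_chat_options; SamplingParams is a stand-in
# for the Llama Stack type, with only the four probed fields assumed.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SamplingParams:
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_tokens: Optional[int] = None


def build_vllm_chat_options(sampling_params: Optional[SamplingParams]) -> dict:
    options = {}
    if sampling_params is not None:
        for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
            # Truthiness check, as in the diff: unset (None) and falsy values are skipped.
            if getattr(sampling_params, attr):
                options[attr] = getattr(sampling_params, attr)
    return options


opts = build_vllm_chat_options(SamplingParams(temperature=0.7, max_tokens=256))
print(sorted(opts.items()))  # [('max_tokens', 256), ('temperature', 0.7)]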
@@ -99,7 +106,7 @@ class VLLMInferenceAdapter(Inference):
         # accumulate sampling params and other options to pass to vLLM
         options = self.get_vllm_chat_options(request)
         vllm_model = self.resolve_vllm_model(request.model)
-        messages = prepare_messages(request)
+        messages = augment_messages_for_tools(request)
         model_input = self.formatter.encode_dialog_prompt(messages)
 
         input_tokens = len(model_input.tokens)
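For context, the resolved model id and the options dict ultimately feed a request against a vLLM server via the OpenAI client imported above. A minimal, hedged sketch of that call is below; the base_url, api_key, and the choice of the chat completions endpoint are assumptions for illustration, since the commit's hunks do not show the request itself.

# Hedged sketch: sending the resolved model and sampling options to a vLLM
# OpenAI-compatible server. The URL, key, and endpoint choice are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

vllm_model = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # e.g. resolve_vllm_model("Llama3.1-8B-Instruct")
options = {"temperature": 0.7, "max_tokens": 256}      # e.g. from get_vllm_chat_options(request)

response = client.chat.completions.create(
    model=vllm_model,
    messages=[{"role": "user", "content": "Say hello."}],
    **options,
)
print(response.choices[0].message.content)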