From b81a3bd46a2cffb6e79e6386b0eb9eec618f6dfa Mon Sep 17 00:00:00 2001
From: Suraj Subramanian <5676233+subramen@users.noreply.github.com>
Date: Tue, 22 Oct 2024 15:56:00 -0400
Subject: [PATCH] Fix import conflict for SamplingParams (#285)

The name conflict between llama_models.llama3.api.datatypes.SamplingParams
and vllm.sampling_params.SamplingParams results in errors while processing
vLLM engine requests.
---
 llama_stack/providers/impls/vllm/vllm.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/llama_stack/providers/impls/vllm/vllm.py b/llama_stack/providers/impls/vllm/vllm.py
index c977c738d..ad3ad8fb7 100644
--- a/llama_stack/providers/impls/vllm/vllm.py
+++ b/llama_stack/providers/impls/vllm/vllm.py
@@ -7,7 +7,7 @@
 import logging
 import os
 import uuid
-from typing import Any
+from typing import Any, AsyncGenerator
 
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import *  # noqa: F403
@@ -15,7 +15,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams as VLLMSamplingParams
 
 from llama_stack.apis.inference import *  # noqa: F403
 
@@ -40,10 +40,10 @@ def _random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
 
-def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
+def _vllm_sampling_params(sampling_params: Any) -> VLLMSamplingParams:
     """Convert sampling params to vLLM sampling params."""
     if sampling_params is None:
-        return SamplingParams()
+        return VLLMSamplingParams()
 
     # TODO convert what I saw in my first test ... but surely there's more to do here
     kwargs = {
@@ -58,7 +58,7 @@ def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
     if sampling_params.repetition_penalty > 0:
         kwargs["repetition_penalty"] = sampling_params.repetition_penalty
 
-    return SamplingParams(**kwargs)
+    return VLLMSamplingParams(**kwargs)
 
 
 class VLLMInferenceImpl(ModelRegistryHelper, Inference):
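
Note (illustrative, not part of the applied diff): a minimal sketch of the
name collision this patch resolves. Both libraries export a class called
SamplingParams, so whichever import runs last silently rebinds the name;
aliasing the vLLM class at import time keeps both classes addressable. The
keyword arguments in the last line are standard vLLM SamplingParams fields,
chosen only for illustration.

    # Both modules export a class named SamplingParams; the later import
    # silently rebinds the name, shadowing the earlier one.
    from llama_models.llama3.api.datatypes import *  # noqa: F403
    from vllm.sampling_params import SamplingParams

    # The fix: bind the vLLM class to an unambiguous alias at import time,
    # so code that builds engine requests always gets the vLLM class.
    from vllm.sampling_params import SamplingParams as VLLMSamplingParams

    # Usage sketch with common vLLM sampling fields.
    vllm_params = VLLMSamplingParams(temperature=0.7, top_p=0.9, max_tokens=128)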