Fix import conflict for SamplingParams (#285)

A conflict between llama_models.llama3.api.datatypes.SamplingParams and vllm.sampling_params.SamplingParams results in errors while processing vLLM engine requests.
Suraj Subramanian 2024-10-22 15:56:00 -04:00 committed by GitHub
parent c06718fbd5
commit b81a3bd46a

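For context, a minimal, runnable sketch of the failure mode and of the aliasing fix applied in the diff below. The two classes and their constructor fields are hypothetical stand-ins, not the real llama_models or vLLM APIs; the point is that the bare name SamplingParams is rebound by whichever import runs last, including a wildcard import, while an explicit alias such as VLLMSamplingParams cannot be clobbered that way.

# Hypothetical stand-ins for the two colliding classes; field names are
# invented for illustration and are not the real APIs.
class LlamaSamplingParams:
    """Plays the role of llama_models.llama3.api.datatypes.SamplingParams."""
    def __init__(self, strategy: str = "greedy", temperature: float = 0.0):
        self.strategy = strategy
        self.temperature = temperature

class EngineSamplingParams:
    """Plays the role of vllm.sampling_params.SamplingParams."""
    def __init__(self, temperature: float = 1.0, top_p: float = 1.0, repetition_penalty: float = 1.0):
        self.temperature = temperature
        self.top_p = top_p
        self.repetition_penalty = repetition_penalty

# Before the fix, both classes were reachable under the single name
# "SamplingParams": the direct engine import bound it, and a later wildcard
# import silently rebound the same name to the request-side class.
SamplingParams = EngineSamplingParams   # from vllm.sampling_params import SamplingParams
SamplingParams = LlamaSamplingParams    # from llama_stack.apis.inference import *  # noqa: F403

try:
    # The converter intends to build engine params, but the name now points
    # at the request-side class, so the call fails while serving the request.
    SamplingParams(top_p=0.9, repetition_penalty=1.1)
except TypeError as exc:
    print(f"shadowed name built the wrong class: {exc}")

# The commit's fix: bind the engine class to an unambiguous alias that no
# wildcard import can rebind.
VLLMSamplingParams = EngineSamplingParams   # ... import SamplingParams as VLLMSamplingParams
params = VLLMSamplingParams(top_p=0.9, repetition_penalty=1.1)
print(params.top_p, params.repetition_penalty)

Aliasing at the import site keeps the wildcard imports (which supply the llama_stack and llama_models API types) intact while making every engine-side construction unambiguous.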

@@ -7,7 +7,7 @@
 import logging
 import os
 import uuid
-from typing import Any
+from typing import Any, AsyncGenerator
 
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import *  # noqa: F403
@@ -15,7 +15,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams as VLLMSamplingParams
 
 from llama_stack.apis.inference import *  # noqa: F403
@@ -40,10 +40,10 @@ def _random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
 
-def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
+def _vllm_sampling_params(sampling_params: Any) -> VLLMSamplingParams:
     """Convert sampling params to vLLM sampling params."""
     if sampling_params is None:
-        return SamplingParams()
+        return VLLMSamplingParams()
 
     # TODO convert what I saw in my first test ... but surely there's more to do here
     kwargs = {
@@ -58,7 +58,7 @@ def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
     if sampling_params.repetition_penalty > 0:
         kwargs["repetition_penalty"] = sampling_params.repetition_penalty
 
-    return SamplingParams(**kwargs)
+    return VLLMSamplingParams(**kwargs)
 
 
 class VLLMInferenceImpl(ModelRegistryHelper, Inference):