mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 10:54:19 +00:00
Fix import conflict for SamplingParams (#285)
Conflict between llama_models.llama3.api.datatypes.SamplingParams and vllm.sampling_params.SamplingParams results in errors while processing VLLM engine requests
This commit is contained in:
parent
c06718fbd5
commit
b81a3bd46a
1 changed file with 5 additions and 5 deletions
|
@ -7,7 +7,7 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
@ -15,7 +15,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
|
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
|
||||||
|
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
|
|
||||||
|
@ -40,10 +40,10 @@ def _random_uuid() -> str:
|
||||||
return str(uuid.uuid4().hex)
|
return str(uuid.uuid4().hex)
|
||||||
|
|
||||||
|
|
||||||
def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
|
def _vllm_sampling_params(sampling_params: Any) -> VLLMSamplingParams:
|
||||||
"""Convert sampling params to vLLM sampling params."""
|
"""Convert sampling params to vLLM sampling params."""
|
||||||
if sampling_params is None:
|
if sampling_params is None:
|
||||||
return SamplingParams()
|
return VLLMSamplingParams()
|
||||||
|
|
||||||
# TODO convert what I saw in my first test ... but surely there's more to do here
|
# TODO convert what I saw in my first test ... but surely there's more to do here
|
||||||
kwargs = {
|
kwargs = {
|
||||||
|
@ -58,7 +58,7 @@ def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
|
||||||
if sampling_params.repetition_penalty > 0:
|
if sampling_params.repetition_penalty > 0:
|
||||||
kwargs["repetition_penalty"] = sampling_params.repetition_penalty
|
kwargs["repetition_penalty"] = sampling_params.repetition_penalty
|
||||||
|
|
||||||
return SamplingParams(**kwargs)
|
return VLLMSamplingParams(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
class VLLMInferenceImpl(ModelRegistryHelper, Inference):
|
class VLLMInferenceImpl(ModelRegistryHelper, Inference):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue