commit 495990b708
Ivan Vykopal, 2025-04-24 01:02:22 -07:00 (committed by GitHub)
3 changed files with 160 additions and 7 deletions

View file

@@ -1,5 +1,6 @@
 import time  # type: ignore
-from typing import Callable
+from typing import Callable, Optional
 import litellm
 import httpx
@@ -24,17 +25,31 @@ class VLLMError(Exception):
 # check if vllm is installed
-def validate_environment(model: str):
+def validate_environment(model: str, vllm_params: dict):
     global llm
     try:
         from vllm import LLM, SamplingParams  # type: ignore
         if llm is None:
-            llm = LLM(model=model)
+            llm = LLM(model=model, **vllm_params)
         return llm, SamplingParams
     except Exception as e:
         raise VLLMError(status_code=0, message=str(e))
+
+
+# extract vllm params from optional params
+def handle_vllm_params(optional_params: Optional[dict]):
+    vllm_params = litellm.VLLMConfig.get_config()
+    if optional_params is None:
+        optional_params = {}
+    for k, v in optional_params.items():
+        if k in vllm_params:
+            vllm_params[k] = v
+    optional_params = {k: v for k, v in optional_params.items() if k not in vllm_params}
+    return vllm_params, optional_params
+
+
 def completion(
     model: str,
@@ -49,8 +64,9 @@ def completion(
     logger_fn=None,
 ):
     global llm
+    vllm_params, optional_params = handle_vllm_params(optional_params)
     try:
-        llm, SamplingParams = validate_environment(model=model)
+        llm, SamplingParams = validate_environment(model=model, vllm_params=vllm_params)
     except Exception as e:
         raise VLLMError(status_code=0, message=str(e))
     sampling_params = SamplingParams(**optional_params)
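
To make the new flow concrete, here is a minimal sketch (not part of the commit) of what `handle_vllm_params` returns for a mixed set of caller params. The input values are illustrative only, and the import path assumes the handler module shown above (`litellm.llms.vllm.completion.handler`, the module patched in the tests below).

# Illustrative sketch, not part of the commit: how handle_vllm_params splits params.
from litellm.llms.vllm.completion.handler import handle_vllm_params

optional_params = {
    "quantization": "bitsandbytes",  # a VLLMConfig key -> routed to LLM(**vllm_params)
    "temperature": 0.2,              # not a VLLMConfig key -> left in optional_params
}
vllm_params, optional_params = handle_vllm_params(optional_params)
# vllm_params: VLLMConfig defaults with quantization overridden to "bitsandbytes"
# optional_params: {"temperature": 0.2}, later unpacked into SamplingParams(**optional_params)
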
@@ -138,8 +154,9 @@ def batch_completions(
             ]
         )
     """
+    vllm_params, optional_params = handle_vllm_params(optional_params)
     try:
-        llm, SamplingParams = validate_environment(model=model)
+        llm, SamplingParams = validate_environment(model=model, vllm_params=vllm_params)
     except Exception as e:
         error_str = str(e)
         raise VLLMError(status_code=0, message=error_str)
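
End to end, the change lets engine-level arguments ride along on a normal `litellm.completion` call when the local `vllm/` provider is used, mirroring `test_vllm_quantized` further down. A rough usage sketch, assuming the `vllm` package and suitable hardware are available; parameter values are examples only.

# Rough usage sketch; requires the vllm package and a GPU. Values are examples only.
import litellm

response = litellm.completion(
    model="vllm/facebook/opt-125m",  # the vllm/ prefix selects the local vLLM SDK handler
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    # Keys known to VLLMConfig are pulled into vllm_params and passed to vllm.LLM(...):
    dtype="auto",
    quantization="bitsandbytes",
    load_format="bitsandbytes",
    # Everything else stays in optional_params and feeds SamplingParams:
    temperature=0.1,
)
print(response.choices[0].message.content)
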

View file

@@ -4,6 +4,9 @@ Translates from OpenAI's `/v1/chat/completions` to the VLLM sdk `llm.generate`.
 NOT RECOMMENDED FOR PRODUCTION USE. Use `hosted_vllm/` instead.
 """
+from typing import Optional, Dict, Any, Union
+import types
+
 from ...hosted_vllm.chat.transformation import HostedVLLMChatConfig
@@ -11,5 +14,78 @@ class VLLMConfig(HostedVLLMChatConfig):
     """
     VLLM SDK supports the same OpenAI params as hosted_vllm.
     """
-    pass
+    model: str
+    tokenizer: Optional[str] = None
+    tokenizer_mode: str = "auto"
+    skip_tokenizer_init: bool = False
+    trust_remote_code: bool = False
+    allowed_local_media_path: str = ""
+    tensor_parallel_size: int = 1
+    dtype: str = "auto"
+    quantization: Optional[str] = None
+    load_format: str = "auto"
+    revision: Optional[str] = None
+    tokenizer_revision: Optional[str] = None
+    seed: int = 0
+    gpu_memory_utilization: float = 0.9
+    swap_space: float = 4
+    cpu_offload_gb: float = 0
+    enforce_eager: Optional[bool] = None
+    max_seq_len_to_capture: int = 8192
+    disable_custom_all_reduce: bool = False
+    disable_async_output_proc: bool = False
+    hf_overrides: Optional[Any] = None
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None
+    task: str = "auto"
+    override_pooler_config: Optional[Any] = None
+    compilation_config: Optional[Union[int, Dict[str, Any]]] = None
+
+    def __init__(
+        self,
+        tokenizer: Optional[str] = None,
+        tokenizer_mode: str = "auto",
+        skip_tokenizer_init: bool = False,
+        trust_remote_code: bool = False,
+        allowed_local_media_path: str = "",
+        tensor_parallel_size: int = 1,
+        dtype: str = "auto",
+        quantization: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+        tokenizer_revision: Optional[str] = None,
+        seed: int = 0,
+        gpu_memory_utilization: float = 0.9,
+        swap_space: float = 4,
+        cpu_offload_gb: float = 0,
+        enforce_eager: Optional[bool] = None,
+        max_seq_len_to_capture: int = 8192,
+        disable_custom_all_reduce: bool = False,
+        disable_async_output_proc: bool = False,
+        hf_overrides: Optional[Any] = None,
+        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+        task: str = "auto",
+        override_pooler_config: Optional[Any] = None,
+        compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
+    ):
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self":
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not k.startswith("_abc")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+        }
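
`get_config()` gathers the class-level engine defaults (skipping dunders, `_abc*` attributes, and methods), and instantiating `VLLMConfig(...)` writes the supplied values back onto the class, following litellm's usual config pattern, so later `get_config()` calls (and therefore `handle_vllm_params`) see the overrides. A small illustrative sketch; the override values are arbitrary.

# Illustrative sketch of the config pattern; override values are arbitrary.
import litellm

defaults = litellm.VLLMConfig.get_config()
assert defaults["gpu_memory_utilization"] == 0.9
assert defaults["quantization"] is None

# Setting values on the config class changes the defaults used by handle_vllm_params:
litellm.VLLMConfig(gpu_memory_utilization=0.8, tensor_parallel_size=2)
assert litellm.VLLMConfig.get_config()["gpu_memory_utilization"] == 0.8
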

View file

@@ -0,0 +1,60 @@
+import pytest
+from unittest.mock import MagicMock, patch
+
+import litellm
+
+
+def test_vllm():
+    litellm.set_verbose = True
+    with patch("litellm.llms.vllm.completion.handler.validate_environment") as mock_client:
+        mock_client.return_value = MagicMock(), MagicMock()
+
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What is the capital of France?"},
+        ]
+        response = litellm.completion(
+            model="vllm/facebook/opt-125m",
+            messages=messages
+        )
+
+        # Verify the request was made
+        mock_client.assert_called_once()
+
+        # Check the request body
+        request_body = mock_client.call_args.kwargs
+        assert request_body["model"] == "facebook/opt-125m"
+        assert request_body["vllm_params"] is not None
+        assert request_body["vllm_params"]["quantization"] is None
+
+
+def test_vllm_quantized():
+    litellm.set_verbose = True
+    with patch("litellm.llms.vllm.completion.handler.validate_environment") as mock_client:
+        mock_client.return_value = MagicMock(), MagicMock()
+
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What is the capital of France?"},
+        ]
+        response = litellm.completion(
+            model="vllm/facebook/opt-125m",
+            messages=messages,
+            dtype="auto",
+            quantization="bitsandbytes",
+            load_format="bitsandbytes"
+        )
+
+        # Verify the request was made
+        mock_client.assert_called_once()
+
+        # Check the request body
+        request_body = mock_client.call_args.kwargs
+        assert request_body["model"] == "facebook/opt-125m"
+        assert request_body["vllm_params"] is not None
+        assert request_body["vllm_params"]["quantization"] == "bitsandbytes"
+        assert request_body["vllm_params"]["dtype"] == "auto"
+        assert request_body["vllm_params"]["load_format"] == "bitsandbytes"
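
A possible companion test (not part of this commit) would exercise `handle_vllm_params` directly instead of going through the mocked `litellm.completion` path; the expected defaults follow the `VLLMConfig` attributes above.

def test_handle_vllm_params_split():
    # Sketch only, not in this commit: checks the param split without mocking completion().
    from litellm.llms.vllm.completion.handler import handle_vllm_params

    vllm_params, optional_params = handle_vllm_params(
        {"quantization": "bitsandbytes", "temperature": 0.7}
    )

    # Engine args recognized by VLLMConfig are pulled out, with defaults filled in...
    assert vllm_params["quantization"] == "bitsandbytes"
    assert vllm_params["gpu_memory_utilization"] == 0.9
    # ...while unrecognized keys are left for SamplingParams.
    assert optional_params == {"temperature": 0.7}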