diff --git a/litellm/llms/vllm/completion/handler.py b/litellm/llms/vllm/completion/handler.py index 1f13082917..e1fc41ff92 100644 --- a/litellm/llms/vllm/completion/handler.py +++ b/litellm/llms/vllm/completion/handler.py @@ -1,5 +1,6 @@ import time # type: ignore -from typing import Callable +from typing import Callable, Optional +import litellm import httpx @@ -24,17 +25,31 @@ class VLLMError(Exception): # check if vllm is installed -def validate_environment(model: str): +def validate_environment(model: str, vllm_params: dict): global llm try: from vllm import LLM, SamplingParams # type: ignore if llm is None: - llm = LLM(model=model) + llm = LLM(model=model, **vllm_params) return llm, SamplingParams except Exception as e: raise VLLMError(status_code=0, message=str(e)) +# extract vllm params from optional params +def handle_vllm_params(optional_params: Optional[dict]): + vllm_params = litellm.VLLMConfig.get_config() + if optional_params is None: + optional_params = {} + + for k, v in optional_params.items(): + if k in vllm_params: + vllm_params[k] = v + + optional_params = {k: v for k, v in optional_params.items() if k not in vllm_params} + + return vllm_params, optional_params + def completion( model: str, @@ -49,8 +64,9 @@ def completion( logger_fn=None, ): global llm + vllm_params, optional_params = handle_vllm_params(optional_params) try: - llm, SamplingParams = validate_environment(model=model) + llm, SamplingParams = validate_environment(model=model, vllm_params=vllm_params) except Exception as e: raise VLLMError(status_code=0, message=str(e)) sampling_params = SamplingParams(**optional_params) @@ -138,8 +154,9 @@ def batch_completions( ] ) """ + vllm_params, optional_params = handle_vllm_params(optional_params) try: - llm, SamplingParams = validate_environment(model=model) + llm, SamplingParams = validate_environment(model=model, vllm_params=vllm_params) except Exception as e: error_str = str(e) raise VLLMError(status_code=0, message=error_str) diff --git a/litellm/llms/vllm/completion/transformation.py b/litellm/llms/vllm/completion/transformation.py index ec4c07e95d..5fa67e4ae3 100644 --- a/litellm/llms/vllm/completion/transformation.py +++ b/litellm/llms/vllm/completion/transformation.py @@ -4,6 +4,9 @@ Translates from OpenAI's `/v1/chat/completions` to the VLLM sdk `llm.generate`. NOT RECOMMENDED FOR PRODUCTION USE. Use `hosted_vllm/` instead. """ +from typing import Optional, Dict, Any, Union +import types + from ...hosted_vllm.chat.transformation import HostedVLLMChatConfig @@ -11,5 +14,78 @@ class VLLMConfig(HostedVLLMChatConfig): """ VLLM SDK supports the same OpenAI params as hosted_vllm. """ - - pass + model: str + tokenizer: Optional[str] = None + tokenizer_mode: str = "auto" + skip_tokenizer_init: bool = False + trust_remote_code: bool = False + allowed_local_media_path: str = "" + tensor_parallel_size: int = 1 + dtype: str = "auto" + quantization: Optional[str] = None + load_format: str = "auto" + revision: Optional[str] = None + tokenizer_revision: Optional[str] = None + seed: int = 0 + gpu_memory_utilization: float = 0.9 + swap_space: float = 4 + cpu_offload_gb: float = 0 + enforce_eager: Optional[bool] = None + max_seq_len_to_capture: int = 8192 + disable_custom_all_reduce: bool = False + disable_async_output_proc: bool = False + hf_overrides: Optional[Any] = None + mm_processor_kwargs: Optional[Dict[str, Any]] = None + task: str = "auto" + override_pooler_config: Optional[Any] = None + compilation_config: Optional[Union[int, Dict[str, Any]]] = None + + def __init__( + self, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = False, + allowed_local_media_path: str = "", + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: float = 4, + cpu_offload_gb: float = 0, + enforce_eager: Optional[bool] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + disable_async_output_proc: bool = False, + hf_overrides: Optional[Any] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + task: str = "auto", + override_pooler_config: Optional[Any] = None, + compilation_config: Optional[Union[int, Dict[str, Any]]] = None, + ): + locals_ = locals().copy() + for key, value in locals_.items(): + if key != "self": + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not k.startswith("_abc") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + } diff --git a/tests/llm_translation/test_vllm.py b/tests/llm_translation/test_vllm.py new file mode 100644 index 0000000000..b87719ef70 --- /dev/null +++ b/tests/llm_translation/test_vllm.py @@ -0,0 +1,60 @@ +import pytest +from unittest.mock import MagicMock, patch + +import litellm + +def test_vllm(): + litellm.set_verbose = True + + with patch("litellm.llms.vllm.completion.handler.validate_environment") as mock_client: + mock_client.return_value = MagicMock(), MagicMock() + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + ] + + response = litellm.completion( + model="vllm/facebook/opt-125m", + messages=messages + ) + + # Verify the request was made + mock_client.assert_called_once() + + # Check the request body + request_body = mock_client.call_args.kwargs + + assert request_body["model"] == "facebook/opt-125m" + assert request_body["vllm_params"] is not None + assert request_body["vllm_params"]["quantization"] is None + + +def test_vllm_quantized(): + litellm.set_verbose = True + + with patch("litellm.llms.vllm.completion.handler.validate_environment") as mock_client: + mock_client.return_value = MagicMock(), MagicMock() + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + ] + + response = litellm.completion( + model="vllm/facebook/opt-125m", + messages=messages, + dtype="auto", + quantization="bitsandbytes", + load_format="bitsandbytes" + ) + + # Verify the request was made + mock_client.assert_called_once() + + # Check the request body + request_body = mock_client.call_args.kwargs + + assert request_body["model"] == "facebook/opt-125m" + assert request_body["vllm_params"] is not None + assert request_body["vllm_params"]["quantization"] == "bitsandbytes" + assert request_body["vllm_params"]["dtype"] == "auto" + assert request_body["vllm_params"]["load_format"] == "bitsandbytes" \ No newline at end of file