From a78db3d5dd9e55b56c4b87a5e431793e3f38f185 Mon Sep 17 00:00:00 2001
From: Ivan Vykopal
Date: Wed, 18 Dec 2024 19:57:48 +0000
Subject: [PATCH 1/3] feat: support vllm quantization

---
 litellm/llms/vllm/completion/handler.py | 29 ++++++++++++++++++++-----
 1 file changed, 23 insertions(+), 6 deletions(-)

diff --git a/litellm/llms/vllm/completion/handler.py b/litellm/llms/vllm/completion/handler.py
index a64ed8974a..51e80be18e 100644
--- a/litellm/llms/vllm/completion/handler.py
+++ b/litellm/llms/vllm/completion/handler.py
@@ -2,7 +2,7 @@ import json
 import os
 import time  # type: ignore
 from enum import Enum
-from typing import Any, Callable
+from typing import Any, Callable, Union

 import httpx

@@ -27,14 +27,31 @@ class VLLMError(Exception):


 # check if vllm is installed
-def validate_environment(model: str):
+def validate_environment(model: str, optional_params: Union[dict, None]):
     global llm
     try:
         from vllm import LLM, SamplingParams  # type: ignore
+
+        default_params = {
+            "tokenizer": None,
+            "tokenizer_mode": "auto",
+            "skip_tokenizer_init": False,
+            "trust_remote_code": False,
+            "dtype": "auto",
+            "quantization": None,
+            "gpu_memory_utilization": 0.9,
+            "load_format": "auto",
+        }
+
+        if optional_params is None:
+            optional_params = {}
+
+        params = {**default_params, **{k: v for k, v in optional_params.items() if k in default_params}}
+        optional_params = {k: v for k, v in optional_params.items() if k not in default_params}

         if llm is None:
-            llm = LLM(model=model)
-        return llm, SamplingParams
+            llm = LLM(model=model, **params)
+        return llm, SamplingParams, optional_params
     except Exception as e:
         raise VLLMError(status_code=0, message=str(e))

@@ -53,7 +70,7 @@
 ):
     global llm
     try:
-        llm, SamplingParams = validate_environment(model=model)
+        llm, SamplingParams, optional_params = validate_environment(model=model, optional_params=optional_params)
     except Exception as e:
         raise VLLMError(status_code=0, message=str(e))
     sampling_params = SamplingParams(**optional_params)
@@ -142,7 +159,7 @@ def batch_completions(
     )
     """
     try:
-        llm, SamplingParams = validate_environment(model=model)
+        llm, SamplingParams, optional_params = validate_environment(model=model, optional_params=optional_params)
     except Exception as e:
         error_str = str(e)
         raise VLLMError(status_code=0, message=error_str)
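Note (PATCH 1/3): this patch splits `optional_params` in two. Keys that appear in
`default_params` (engine-level arguments such as `quantization`, `dtype`, and
`gpu_memory_utilization`) are forwarded to the `LLM(...)` constructor, and whatever
remains still feeds `SamplingParams`. A minimal usage sketch, assuming extra kwargs
reach `optional_params` unchanged; the model name and argument values here are
illustrative, not part of the patch:

    import litellm

    # "quantization" and "gpu_memory_utilization" match keys in default_params,
    # so they are routed to LLM(...); "temperature" stays in optional_params
    # and ends up in SamplingParams(**optional_params).
    response = litellm.completion(
        model="vllm/facebook/opt-125m",
        messages=[{"role": "user", "content": "Hello"}],
        quantization="bitsandbytes",
        gpu_memory_utilization=0.8,
        temperature=0.2,
    )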
"quantization": None, - "gpu_memory_utilization": 0.9, - "load_format": "auto", - } - - if optional_params is None: - optional_params = {} - - params = {**default_params, **{k: v for k, v in optional_params.items() if k in default_params}} - optional_params = {k: v for k, v in optional_params.items() if k not in default_params} if llm is None: - llm = LLM(model=model, **params) - return llm, SamplingParams, optional_params + llm = LLM(model=model, **vllm_params) + return llm, SamplingParams except Exception as e: raise VLLMError(status_code=0, message=str(e)) +# extract vllm params from optional params +def handle_vllm_params(optional_params: Optional[dict]): + vllm_params = litellm.VLLMConfig.get_config() + if optional_params is None: + optional_params = {} + + for k, v in optional_params.items(): + if k in vllm_params: + vllm_params[k] = v + + optional_params = {k: v for k, v in optional_params.items() if k not in vllm_params} + + return vllm_params, optional_params + def completion( model: str, @@ -69,8 +66,9 @@ def completion( logger_fn=None, ): global llm + vllm_params, optional_params = handle_vllm_params(optional_params) try: - llm, SamplingParams, optional_params = validate_environment(model=model, optional_params=optional_params) + llm, SamplingParams = validate_environment(model=model, vllm_params=vllm_params) except Exception as e: raise VLLMError(status_code=0, message=str(e)) sampling_params = SamplingParams(**optional_params) @@ -158,8 +156,9 @@ def batch_completions( ] ) """ + vllm_params, optional_params = handle_vllm_params(optional_params) try: - llm, SamplingParams, optional_params = validate_environment(model=model, optional_params=optional_params) + llm, SamplingParams = validate_environment(model=model, vllm_params=vllm_params) except Exception as e: error_str = str(e) raise VLLMError(status_code=0, message=error_str) diff --git a/litellm/llms/vllm/completion/transformation.py b/litellm/llms/vllm/completion/transformation.py index 022812b769..27296ec04e 100644 --- a/litellm/llms/vllm/completion/transformation.py +++ b/litellm/llms/vllm/completion/transformation.py @@ -4,7 +4,8 @@ Translates from OpenAI's `/v1/chat/completions` to the VLLM sdk `llm.generate`. NOT RECOMMENDED FOR PRODUCTION USE. Use `hosted_vllm/` instead. """ -from typing import List +from typing import List, Optional, Dict, Any, Union +import types from litellm.types.llms.openai import AllMessageValues @@ -15,5 +16,78 @@ class VLLMConfig(HostedVLLMChatConfig): """ VLLM SDK supports the same OpenAI params as hosted_vllm. 
""" - - pass + model: str + tokenizer: Optional[str] = None + tokenizer_mode: str = "auto" + skip_tokenizer_init: bool = False + trust_remote_code: bool = False + allowed_local_media_path: str = "" + tensor_parallel_size: int = 1 + dtype: str = "auto" + quantization: Optional[str] = None + load_format: str = "auto" + revision: Optional[str] = None + tokenizer_revision: Optional[str] = None + seed: int = 0 + gpu_memory_utilization: float = 0.9 + swap_space: float = 4 + cpu_offload_gb: float = 0 + enforce_eager: Optional[bool] = None + max_seq_len_to_capture: int = 8192 + disable_custom_all_reduce: bool = False + disable_async_output_proc: bool = False + hf_overrides: Optional[Any] = None + mm_processor_kwargs: Optional[Dict[str, Any]] = None + task: str = "auto" + override_pooler_config: Optional[Any] = None + compilation_config: Optional[Union[int, Dict[str, Any]]] = None + + def __init__( + self, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = False, + allowed_local_media_path: str = "", + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: float = 4, + cpu_offload_gb: float = 0, + enforce_eager: Optional[bool] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + disable_async_output_proc: bool = False, + hf_overrides: Optional[Any] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + task: str = "auto", + override_pooler_config: Optional[Any] = None, + compilation_config: Optional[Union[int, Dict[str, Any]]] = None, + ): + locals_ = locals().copy() + for key, value in locals_.items(): + if key != "self": + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not k.startswith("_abc") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + } From 80494808830df10d6cdbe5fe76176b794e7d680c Mon Sep 17 00:00:00 2001 From: Ivan Vykopal Date: Thu, 19 Dec 2024 20:21:27 +0100 Subject: [PATCH 3/3] feat: add tests for vllm --- tests/llm_translation/test_vllm.py | 60 ++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 tests/llm_translation/test_vllm.py diff --git a/tests/llm_translation/test_vllm.py b/tests/llm_translation/test_vllm.py new file mode 100644 index 0000000000..b87719ef70 --- /dev/null +++ b/tests/llm_translation/test_vllm.py @@ -0,0 +1,60 @@ +import pytest +from unittest.mock import MagicMock, patch + +import litellm + +def test_vllm(): + litellm.set_verbose = True + + with patch("litellm.llms.vllm.completion.handler.validate_environment") as mock_client: + mock_client.return_value = MagicMock(), MagicMock() + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + ] + + response = litellm.completion( + model="vllm/facebook/opt-125m", + messages=messages + ) + + # Verify the request was made + mock_client.assert_called_once() + + # Check the request body + request_body = mock_client.call_args.kwargs + + assert request_body["model"] == "facebook/opt-125m" + assert request_body["vllm_params"] is not None + assert 
+        assert request_body["vllm_params"]["quantization"] is None
+
+
+def test_vllm_quantized():
+    litellm.set_verbose = True
+
+    with patch("litellm.llms.vllm.completion.handler.validate_environment") as mock_client:
+        mock_client.return_value = MagicMock(), MagicMock()
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What is the capital of France?"},
+        ]
+
+        response = litellm.completion(
+            model="vllm/facebook/opt-125m",
+            messages=messages,
+            dtype="auto",
+            quantization="bitsandbytes",
+            load_format="bitsandbytes"
+        )
+
+        # Verify the request was made
+        mock_client.assert_called_once()
+
+        # Check the request body
+        request_body = mock_client.call_args.kwargs
+
+        assert request_body["model"] == "facebook/opt-125m"
+        assert request_body["vllm_params"] is not None
+        assert request_body["vllm_params"]["quantization"] == "bitsandbytes"
+        assert request_body["vllm_params"]["dtype"] == "auto"
+        assert request_body["vllm_params"]["load_format"] == "bitsandbytes"
\ No newline at end of file
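Note (PATCH 3/3): both tests stub out `validate_environment`, so they exercise only
the parameter routing and require neither a GPU nor a local vllm install. Something
like the following should run them:

    pytest tests/llm_translation/test_vllm.py -v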