Mirror of https://github.com/BerriAI/litellm.git
Merge f1638a3c30 into b82af5b826
Commit: 495990b708
3 changed files with 160 additions and 7 deletions
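
Summary of the change: vLLM engine arguments (dtype, quantization, tensor_parallel_size, ...) passed to litellm.completion are now split off and forwarded to the vllm.LLM(...) constructor, while the remaining params still go to SamplingParams. A minimal usage sketch, mirroring the new test_vllm_quantized test at the bottom of this diff (model name taken from the tests; actually running it requires the vllm package installed locally):

import litellm

response = litellm.completion(
    model="vllm/facebook/opt-125m",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
    # engine-level kwargs picked up by handle_vllm_params and passed to vllm.LLM(...)
    dtype="auto",
    quantization="bitsandbytes",
    load_format="bitsandbytes",
)
print(response)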

litellm/llms/vllm/completion/handler.py
@@ -1,5 +1,6 @@
 import time  # type: ignore
-from typing import Callable
+from typing import Callable, Optional
+import litellm
 
 import httpx
 
@@ -24,17 +25,31 @@ class VLLMError(Exception):
 
 
 # check if vllm is installed
-def validate_environment(model: str):
+def validate_environment(model: str, vllm_params: dict):
     global llm
     try:
         from vllm import LLM, SamplingParams  # type: ignore
 
         if llm is None:
-            llm = LLM(model=model)
+            llm = LLM(model=model, **vllm_params)
         return llm, SamplingParams
     except Exception as e:
         raise VLLMError(status_code=0, message=str(e))
 
 
+# extract vllm params from optional params
+def handle_vllm_params(optional_params: Optional[dict]):
+    vllm_params = litellm.VLLMConfig.get_config()
+    if optional_params is None:
+        optional_params = {}
+
+    for k, v in optional_params.items():
+        if k in vllm_params:
+            vllm_params[k] = v
+
+    optional_params = {k: v for k, v in optional_params.items() if k not in vllm_params}
+
+    return vllm_params, optional_params
+
 
 def completion(
     model: str,
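
Reviewer note (not part of the diff): handle_vllm_params starts from litellm.VLLMConfig.get_config(), overwrites any matching keys with values from optional_params, then strips those keys out of optional_params so they are not passed on to SamplingParams. A tiny illustrative sketch, assuming the config keys shown in the transformation file further down; max_tokens here is just an example of a sampling-side key:

optional_params = {"quantization": "bitsandbytes", "max_tokens": 64}
vllm_params, optional_params = handle_vllm_params(optional_params)
# vllm_params["quantization"] == "bitsandbytes"  -> goes to vllm.LLM(model=..., **vllm_params)
# optional_params == {"max_tokens": 64}          -> goes to SamplingParams(**optional_params)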
@@ -49,8 +64,9 @@ def completion(
     logger_fn=None,
 ):
     global llm
+    vllm_params, optional_params = handle_vllm_params(optional_params)
     try:
-        llm, SamplingParams = validate_environment(model=model)
+        llm, SamplingParams = validate_environment(model=model, vllm_params=vllm_params)
     except Exception as e:
         raise VLLMError(status_code=0, message=str(e))
     sampling_params = SamplingParams(**optional_params)
@@ -138,8 +154,9 @@ def batch_completions(
         ]
     )
     """
+    vllm_params, optional_params = handle_vllm_params(optional_params)
     try:
-        llm, SamplingParams = validate_environment(model=model)
+        llm, SamplingParams = validate_environment(model=model, vllm_params=vllm_params)
     except Exception as e:
         error_str = str(e)
         raise VLLMError(status_code=0, message=error_str)

litellm/llms/vllm/completion/transformation.py
@@ -4,6 +4,9 @@ Translates from OpenAI's `/v1/chat/completions` to the VLLM sdk `llm.generate`.
 NOT RECOMMENDED FOR PRODUCTION USE. Use `hosted_vllm/` instead.
 """
 
+from typing import Optional, Dict, Any, Union
+import types
+
 from ...hosted_vllm.chat.transformation import HostedVLLMChatConfig
 
@@ -11,5 +14,78 @@ class VLLMConfig(HostedVLLMChatConfig):
     """
     VLLM SDK supports the same OpenAI params as hosted_vllm.
     """
 
-    pass
+    model: str
+    tokenizer: Optional[str] = None
+    tokenizer_mode: str = "auto"
+    skip_tokenizer_init: bool = False
+    trust_remote_code: bool = False
+    allowed_local_media_path: str = ""
+    tensor_parallel_size: int = 1
+    dtype: str = "auto"
+    quantization: Optional[str] = None
+    load_format: str = "auto"
+    revision: Optional[str] = None
+    tokenizer_revision: Optional[str] = None
+    seed: int = 0
+    gpu_memory_utilization: float = 0.9
+    swap_space: float = 4
+    cpu_offload_gb: float = 0
+    enforce_eager: Optional[bool] = None
+    max_seq_len_to_capture: int = 8192
+    disable_custom_all_reduce: bool = False
+    disable_async_output_proc: bool = False
+    hf_overrides: Optional[Any] = None
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None
+    task: str = "auto"
+    override_pooler_config: Optional[Any] = None
+    compilation_config: Optional[Union[int, Dict[str, Any]]] = None
+
+    def __init__(
+        self,
+        tokenizer: Optional[str] = None,
+        tokenizer_mode: str = "auto",
+        skip_tokenizer_init: bool = False,
+        trust_remote_code: bool = False,
+        allowed_local_media_path: str = "",
+        tensor_parallel_size: int = 1,
+        dtype: str = "auto",
+        quantization: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+        tokenizer_revision: Optional[str] = None,
+        seed: int = 0,
+        gpu_memory_utilization: float = 0.9,
+        swap_space: float = 4,
+        cpu_offload_gb: float = 0,
+        enforce_eager: Optional[bool] = None,
+        max_seq_len_to_capture: int = 8192,
+        disable_custom_all_reduce: bool = False,
+        disable_async_output_proc: bool = False,
+        hf_overrides: Optional[Any] = None,
+        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+        task: str = "auto",
+        override_pooler_config: Optional[Any] = None,
+        compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
+    ):
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self":
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not k.startswith("_abc")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+        }

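For reviewers: because __init__ writes each argument back onto the class via setattr(self.__class__, ...), instantiating the config once sets process-wide engine defaults that get_config() (and therefore handle_vllm_params in the handler) will pick up on later calls. A short sketch of that pattern with arbitrary example values, assuming VLLMConfig is exported at the litellm top level as the handler code above expects:

import litellm

# set engine defaults once; subsequent vllm/ completion calls inherit them
litellm.VLLMConfig(gpu_memory_utilization=0.8, tensor_parallel_size=2)
assert litellm.VLLMConfig.get_config()["gpu_memory_utilization"] == 0.8
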
tests/llm_translation/test_vllm.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+import pytest
+from unittest.mock import MagicMock, patch
+
+import litellm
+
+
+def test_vllm():
+    litellm.set_verbose = True
+
+    with patch("litellm.llms.vllm.completion.handler.validate_environment") as mock_client:
+        mock_client.return_value = MagicMock(), MagicMock()
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What is the capital of France?"},
+        ]
+
+        response = litellm.completion(
+            model="vllm/facebook/opt-125m",
+            messages=messages
+        )
+
+        # Verify the request was made
+        mock_client.assert_called_once()
+
+        # Check the request body
+        request_body = mock_client.call_args.kwargs
+
+        assert request_body["model"] == "facebook/opt-125m"
+        assert request_body["vllm_params"] is not None
+        assert request_body["vllm_params"]["quantization"] is None
+
+
+def test_vllm_quantized():
+    litellm.set_verbose = True
+
+    with patch("litellm.llms.vllm.completion.handler.validate_environment") as mock_client:
+        mock_client.return_value = MagicMock(), MagicMock()
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What is the capital of France?"},
+        ]
+
+        response = litellm.completion(
+            model="vllm/facebook/opt-125m",
+            messages=messages,
+            dtype="auto",
+            quantization="bitsandbytes",
+            load_format="bitsandbytes"
+        )
+
+        # Verify the request was made
+        mock_client.assert_called_once()
+
+        # Check the request body
+        request_body = mock_client.call_args.kwargs
+
+        assert request_body["model"] == "facebook/opt-125m"
+        assert request_body["vllm_params"] is not None
+        assert request_body["vllm_params"]["quantization"] == "bitsandbytes"
+        assert request_body["vllm_params"]["dtype"] == "auto"
+        assert request_body["vllm_params"]["load_format"] == "bitsandbytes"
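
To run just these new tests locally, something like `python -m pytest tests/llm_translation/test_vllm.py -q` from the repo root should work, assuming pytest and the repo's test dependencies are installed; the vllm package itself should not be needed, since validate_environment is mocked in both tests.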