Litellm vllm refactor (#7158)

* refactor(vllm/): move vllm to use base llm config

* test: mark flaky test
Krish Dholakia 2024-12-10 21:48:35 -08:00 committed by GitHub
parent e9fbefca5d
commit cd9b92b402
9 changed files with 48 additions and 8 deletions

View file

@@ -1184,6 +1184,7 @@ from .llms.azure.azure import (
 from .llms.azure.chat.gpt_transformation import AzureOpenAIConfig
 from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
+from .llms.vllm.completion.transformation import VLLMConfig
 from .llms.deepseek.chat.transformation import DeepSeekChatConfig
 from .llms.lm_studio.chat.transformation import LMStudioChatConfig
 from .llms.lm_studio.embed.transformation import LmStudioEmbeddingConfig

View file

@@ -5,7 +5,7 @@ import litellm
 from litellm._logging import print_verbose
 from litellm.utils import get_optional_params
-from ..llms import vllm
+from ..llms.vllm.completion import handler as vllm_handler
 
 
 def batch_completion(
@@ -83,7 +83,7 @@ def batch_completion(
             model=model,
             custom_llm_provider=custom_llm_provider,
         )
-        results = vllm.batch_completions(
+        results = vllm_handler.batch_completions(
             model=model,
             messages=batch_messages,
             custom_prompt_dict=litellm.custom_prompt_dict,

View file

@@ -29,7 +29,7 @@ LITELLM_CHAT_PROVIDERS = [
     # "sagemaker",
     # "sagemaker_chat",
     # "bedrock",
-    # "vllm",
+    "vllm",
     # "nlp_cloud",
     # "petals",
     # "oobabooga",

View file

@@ -58,6 +58,8 @@ def get_supported_openai_params( # noqa: PLR0915
     return litellm.GroqChatConfig().get_supported_openai_params(model=model)
 elif custom_llm_provider == "hosted_vllm":
     return litellm.HostedVLLMChatConfig().get_supported_openai_params(model=model)
+elif custom_llm_provider == "vllm":
+    return litellm.VLLMConfig().get_supported_openai_params(model=model)
 elif custom_llm_provider == "deepseek":
     return [
         # https://platform.deepseek.com/api-docs/api/create-chat-completion
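
For context on how this new branch gets exercised: a minimal sketch, assuming `litellm.get_supported_openai_params` is the public wrapper over this dispatcher and that `VLLMConfig` adds no overrides of its own, so the `vllm` and `hosted_vllm` providers should report the same parameter list.

import litellm

# Both providers resolve through HostedVLLMChatConfig's parameter list,
# since the new VLLMConfig (added later in this commit) only subclasses it.
vllm_params = litellm.get_supported_openai_params(
    model="facebook/opt-125m", custom_llm_provider="vllm"
)
hosted_params = litellm.get_supported_openai_params(
    model="facebook/opt-125m", custom_llm_provider="hosted_vllm"
)
assert vllm_params == hosted_params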

View file

@@ -9,7 +9,7 @@ import requests # type: ignore
 from litellm.utils import ModelResponse, Usage
 
-from .prompt_templates.factory import custom_prompt, prompt_factory
+from ...prompt_templates.factory import custom_prompt, prompt_factory
 
 llm = None
@@ -29,7 +29,8 @@ class VLLMError(Exception):
 def validate_environment(model: str):
     global llm
     try:
-        from vllm import LLM, SamplingParams # type: ignore
+        from litellm.llms.vllm.completion.handler import LLM # type: ignore
+        from litellm.llms.vllm.completion.handler import SamplingParams # type: ignore
 
         if llm is None:
             llm = LLM(model=model)

View file

@@ -0,0 +1,19 @@
+"""
+Translates from OpenAI's `/v1/chat/completions` to the VLLM sdk `llm.generate`.
+
+NOT RECOMMENDED FOR PRODUCTION USE. Use `hosted_vllm/` instead.
+"""
+
+from typing import List
+
+from litellm.types.llms.openai import AllMessageValues
+
+from ...hosted_vllm.chat.transformation import HostedVLLMChatConfig
+
+
+class VLLMConfig(HostedVLLMChatConfig):
+    """
+    VLLM SDK supports the same OpenAI params as hosted_vllm.
+    """
+
+    pass
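
A rough sketch of what the inherited transformation gives you, using the same keyword arguments as the `get_optional_params` call site later in this commit; the exact mapping depends on what `HostedVLLMChatConfig` supports, so the printed result is only illustrative.

import litellm

config = litellm.VLLMConfig()

# Map OpenAI-style kwargs onto the provider's optional params; keyword names
# mirror the utils.get_optional_params hunk in this commit.
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 256, "temperature": 0.2},
    optional_params={},
    model="facebook/opt-125m",
    drop_params=False,
)
print(optional_params)  # e.g. {"max_tokens": 256, "temperature": 0.2}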

View file

@@ -94,7 +94,6 @@ from .llms import (
     palm,
     petals,
     replicate,
-    vllm,
 )
 from .llms.ai21 import completion as ai21
 from .llms.anthropic.chat import AnthropicChatCompletion
@@ -160,6 +159,7 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler im
 from .llms.vertex_ai_and_google_ai_studio.vertex_model_garden.main import (
     VertexAIModelGardenModels,
 )
+from .llms.vllm.completion import handler
 from .llms.watsonx.chat.handler import WatsonXChatHandler
 from .llms.watsonx.completion.handler import IBMWatsonXAI
 from .types.llms.openai import (
@@ -2691,7 +2691,7 @@ def completion( # type: ignore # noqa: PLR0915
             response = response
         elif custom_llm_provider == "vllm":
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-            model_response = vllm.completion(
+            model_response = handler.completion(
                 model=model,
                 messages=messages,
                 custom_prompt_dict=custom_prompt_dict,
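
For reference, the branch above is hit when a request is routed to the local VLLM SDK provider. A minimal usage sketch, assuming the usual `vllm/<model>` prefix routing and a locally installed `vllm` package (per the new transformation module's docstring, `hosted_vllm/` remains the recommended path):

import litellm

# Requires the vllm SDK locally, since handler.completion() ends up
# instantiating vllm.LLM(model=...) via validate_environment().
response = litellm.completion(
    model="vllm/facebook/opt-125m",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response.choices[0].message.content)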

View file

@@ -3541,7 +3541,21 @@ def get_optional_params( # noqa: PLR0915
                 else False
             ),
         )
+    elif custom_llm_provider == "vllm":
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.VLLMConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
+        )
     elif custom_llm_provider == "groq":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
@@ -6319,6 +6333,8 @@ class ProviderConfigManager:
             return litellm.LMStudioChatConfig()
         elif litellm.LlmProviders.GALADRIEL == provider:
             return litellm.GaladrielChatConfig()
+        elif litellm.LlmProviders.VLLM == provider:
+            return litellm.VLLMConfig()
         elif litellm.LlmProviders.OLLAMA == provider:
             return litellm.OllamaConfig()
         return litellm.OpenAIGPTConfig()
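
With the `ProviderConfigManager` branch above, config resolution for the `vllm` provider no longer falls through to the OpenAI default. A sketch of the lookup, with the caveat that the enclosing method name and signature (`get_provider_chat_config(model, provider)`) are assumed from the surrounding class rather than shown in this hunk:

import litellm
from litellm.utils import ProviderConfigManager

# Assumed method name/signature; the hunk only shows the body of the dispatch.
config = ProviderConfigManager.get_provider_chat_config(
    model="facebook/opt-125m",
    provider=litellm.LlmProviders.VLLM,
)
assert isinstance(config, litellm.VLLMConfig)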

View file

@@ -513,6 +513,7 @@ async def test_get_current_provider_spend():
     assert spend == 50.5
 
 
+@pytest.mark.flaky(retries=6, delay=2)
 @pytest.mark.asyncio
 async def test_get_current_provider_budget_reset_at():
     """