Litellm vllm refactor (#7158)
* refactor(vllm/): move vllm to use base llm config
* test: mark flaky test
This commit is contained in:
parent: e9fbefca5d
commit: cd9b92b402
9 changed files with 48 additions and 8 deletions
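For orientation, a minimal sketch of how the refactored path is reached from litellm's public API. The model name below is a placeholder, and running it assumes the local vllm SDK is installed, since the `vllm/` prefix selects the in-process SDK provider rather than an HTTP endpoint.

import litellm

# "vllm/<model>" routes through the vLLM SDK code path touched by this commit
# (litellm/llms/vllm/completion/); the model name is only a placeholder.
response = litellm.completion(
    model="vllm/facebook/opt-125m",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    temperature=0.2,
    max_tokens=64,
)
print(response.choices[0].message.content)
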
@@ -1184,6 +1184,7 @@ from .llms.azure.azure import (
 from .llms.azure.chat.gpt_transformation import AzureOpenAIConfig
 from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
+from .llms.vllm.completion.transformation import VLLMConfig
 from .llms.deepseek.chat.transformation import DeepSeekChatConfig
 from .llms.lm_studio.chat.transformation import LMStudioChatConfig
 from .llms.lm_studio.embed.transformation import LmStudioEmbeddingConfig

@@ -5,7 +5,7 @@ import litellm
 from litellm._logging import print_verbose
 from litellm.utils import get_optional_params

-from ..llms import vllm
+from ..llms.vllm.completion import handler as vllm_handler


 def batch_completion(

@@ -83,7 +83,7 @@ def batch_completion(
             model=model,
             custom_llm_provider=custom_llm_provider,
         )
-        results = vllm.batch_completions(
+        results = vllm_handler.batch_completions(
             model=model,
             messages=batch_messages,
             custom_prompt_dict=litellm.custom_prompt_dict,

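As a usage sketch of the path above (placeholder model name; the local vllm SDK is assumed to be installed), `litellm.batch_completion` fans each inner message list out and, for the vllm provider, ends up in `vllm_handler.batch_completions`:

import litellm

# Each inner list is one independent chat; responses come back in the same order.
responses = litellm.batch_completion(
    model="vllm/facebook/opt-125m",  # placeholder model name
    messages=[
        [{"role": "user", "content": "What is the capital of France?"}],
        [{"role": "user", "content": "Summarize vLLM in one sentence."}],
    ],
    max_tokens=64,
)
for response in responses:
    print(response.choices[0].message.content)
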
@@ -29,7 +29,7 @@ LITELLM_CHAT_PROVIDERS = [
     # "sagemaker",
     # "sagemaker_chat",
     # "bedrock",
-    # "vllm",
+    "vllm",
     # "nlp_cloud",
     # "petals",
     # "oobabooga",

@@ -58,6 +58,8 @@ def get_supported_openai_params(  # noqa: PLR0915
         return litellm.GroqChatConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "hosted_vllm":
         return litellm.HostedVLLMChatConfig().get_supported_openai_params(model=model)
+    elif custom_llm_provider == "vllm":
+        return litellm.VLLMConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "deepseek":
         return [
             # https://platform.deepseek.com/api-docs/api/create-chat-completion

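From the caller's side, the new branch can be exercised as below; the exact list returned is whatever `HostedVLLMChatConfig` reports, so the comment only illustrates the shape:

from litellm import get_supported_openai_params

# The "vllm" provider now resolves through litellm.VLLMConfig(), which inherits
# its parameter list from HostedVLLMChatConfig.
params = get_supported_openai_params(
    model="facebook/opt-125m",  # placeholder model name
    custom_llm_provider="vllm",
)
print(params)  # e.g. ["temperature", "max_tokens", ...] depending on the config
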
@@ -9,7 +9,7 @@ import requests  # type: ignore

 from litellm.utils import ModelResponse, Usage

-from .prompt_templates.factory import custom_prompt, prompt_factory
+from ...prompt_templates.factory import custom_prompt, prompt_factory

 llm = None

@@ -29,7 +29,8 @@ class VLLMError(Exception):
 def validate_environment(model: str):
     global llm
     try:
-        from vllm import LLM, SamplingParams  # type: ignore
+        from litellm.llms.vllm.completion.handler import LLM  # type: ignore
+        from litellm.llms.vllm.completion.handler import SamplingParams  # type: ignore

         if llm is None:
             llm = LLM(model=model)

litellm/llms/vllm/completion/transformation.py
Normal file
19
litellm/llms/vllm/completion/transformation.py
Normal file
|
@@ -0,0 +1,19 @@
+"""
+Translates from OpenAI's `/v1/chat/completions` to the VLLM sdk `llm.generate`.
+
+NOT RECOMMENDED FOR PRODUCTION USE. Use `hosted_vllm/` instead.
+"""
+
+from typing import List
+
+from litellm.types.llms.openai import AllMessageValues
+
+from ...hosted_vllm.chat.transformation import HostedVLLMChatConfig
+
+
+class VLLMConfig(HostedVLLMChatConfig):
+    """
+    VLLM SDK supports the same OpenAI params as hosted_vllm.
+    """
+
+    pass

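Because `VLLMConfig` only subclasses `HostedVLLMChatConfig`, the two providers expose the same OpenAI-parameter surface; a quick sketch using the import paths shown in the hunks above (placeholder model name):

from litellm.llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
from litellm.llms.vllm.completion.transformation import VLLMConfig

config = VLLMConfig()
assert isinstance(config, HostedVLLMChatConfig)

# Same supported-params list as the hosted_vllm provider.
print(config.get_supported_openai_params(model="facebook/opt-125m"))
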
@@ -94,7 +94,6 @@ from .llms import (
     palm,
     petals,
     replicate,
-    vllm,
 )
 from .llms.ai21 import completion as ai21
 from .llms.anthropic.chat import AnthropicChatCompletion

@@ -160,6 +159,7 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler im
 from .llms.vertex_ai_and_google_ai_studio.vertex_model_garden.main import (
     VertexAIModelGardenModels,
 )
+from .llms.vllm.completion import handler
 from .llms.watsonx.chat.handler import WatsonXChatHandler
 from .llms.watsonx.completion.handler import IBMWatsonXAI
 from .types.llms.openai import (

@@ -2691,7 +2691,7 @@ def completion(  # type: ignore # noqa: PLR0915
             response = response
         elif custom_llm_provider == "vllm":
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-            model_response = vllm.completion(
+            model_response = handler.completion(
                 model=model,
                 messages=messages,
                 custom_prompt_dict=custom_prompt_dict,

@@ -3541,7 +3541,21 @@ def get_optional_params(  # noqa: PLR0915
                 else False
             ),
         )
+    elif custom_llm_provider == "vllm":
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.VLLMConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
+        )
     elif custom_llm_provider == "groq":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider

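Isolating the call shape of the new branch for clarity; the argument values here are placeholders, and `_check_valid_arg` is an internal helper of `get_optional_params`, so it is omitted:

import litellm

# Mirrors the new "vllm" branch: VLLMConfig maps OpenAI-style params for the
# vLLM SDK path; drop_params controls whether unsupported params are dropped.
optional_params = litellm.VLLMConfig().map_openai_params(
    non_default_params={"temperature": 0.2, "max_tokens": 128},  # placeholder values
    optional_params={},
    model="facebook/opt-125m",  # placeholder model name
    drop_params=False,
)
print(optional_params)
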
@@ -6319,6 +6333,8 @@ class ProviderConfigManager:
             return litellm.LMStudioChatConfig()
         elif litellm.LlmProviders.GALADRIEL == provider:
             return litellm.GaladrielChatConfig()
+        elif litellm.LlmProviders.VLLM == provider:
+            return litellm.VLLMConfig()
         elif litellm.LlmProviders.OLLAMA == provider:
             return litellm.OllamaConfig()
         return litellm.OpenAIGPTConfig()

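For completeness, a sketch of how this registry change is typically consumed. The hunk does not show the enclosing method, so the name `get_provider_chat_config` and its signature are an assumption about the surrounding code:

import litellm
from litellm.utils import ProviderConfigManager

# Assumption: the branch above lives in ProviderConfigManager.get_provider_chat_config,
# so a VLLM lookup now returns VLLMConfig instead of the generic OpenAI config.
config = ProviderConfigManager.get_provider_chat_config(
    model="facebook/opt-125m",  # placeholder model name
    provider=litellm.LlmProviders.VLLM,
)
print(type(config))  # expected: litellm.VLLMConfig
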
@@ -513,6 +513,7 @@ async def test_get_current_provider_spend():
     assert spend == 50.5


+@pytest.mark.flaky(retries=6, delay=2)
 @pytest.mark.asyncio
 async def test_get_current_provider_budget_reset_at():
     """