diff --git a/litellm/llms/vllm.py b/litellm/llms/vllm.py
index 0f446ae81..bc46e77d0 100644
--- a/litellm/llms/vllm.py
+++ b/litellm/llms/vllm.py
@@ -3,7 +3,7 @@ import json
 from enum import Enum
 import requests
 import time
-from typing import Callable
+from typing import Callable, Any
 from litellm.utils import ModelResponse
 from .prompt_templates.factory import prompt_factory, custom_prompt
 llm = None
@@ -16,7 +16,7 @@ class VLLMError(Exception):
         )  # Call the base class constructor with the parameters it needs
 
 # check if vllm is installed
-def validate_environment(model: str, llm: any=None):
+def validate_environment(model: str, llm: Any = None):
     try:
         from vllm import LLM, SamplingParams
         if llm is None:
@@ -37,6 +37,7 @@ def completion(
     litellm_params=None,
     logger_fn=None,
 ):
+    global llm
     try:
         llm, SamplingParams = validate_environment(model=model)
     except Exception as e:
@@ -62,7 +63,10 @@ def completion(
         additional_args={"complete_input_dict": sampling_params},
     )
 
-    outputs = llm.generate(prompt, sampling_params)
+    if llm:
+        outputs = llm.generate(prompt, sampling_params)
+    else:
+        raise VLLMError(status_code=0, message="Need to pass in a model name to initialize vllm")
 
 
     ## COMPLETION CALL
@@ -128,10 +132,11 @@ def batch_completions(
     try:
         llm, SamplingParams = validate_environment(model=model, llm=llm)
     except Exception as e:
-        if "data parallel group is already initialized" in e:
+        error_str = str(e)
+        if "data parallel group is already initialized" in error_str:
             pass
         else:
-            raise VLLMError(status_code=0, message=str(e))
+            raise VLLMError(status_code=0, message=error_str)
     sampling_params = SamplingParams(**optional_params)
     prompts = []
     if model in custom_prompt_dict:
@@ -150,7 +155,10 @@ def batch_completions(
             prompt = prompt_factory(model=model, messages=message)
             prompts.append(prompt)
 
-    outputs = llm.generate(prompts, sampling_params)
+    if llm:
+        outputs = llm.generate(prompts, sampling_params)
+    else:
+        raise VLLMError(status_code=0, message="Need to pass in a model name to initialize vllm")
 
     final_outputs = []
     for output in outputs:
diff --git a/litellm/main.py b/litellm/main.py
index 2e480b2dc..395cae31b 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -844,6 +844,13 @@ def batch_completion(
     frequency_penalty=0,
     logit_bias: dict = {},
     user: str = "",
+    # Optional liteLLM function params
+    *,
+    return_async=False,
+    api_key: Optional[str] = None,
+    api_version: Optional[str] = None,
+    api_base: Optional[str] = None,
+    force_timeout=600,  # used by text-bison only
     top_k=40,
     custom_llm_provider=None,
 ):
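
Note that the `batch_completions` hunk fixes a latent bug rather than just tidying types: the old code tested `"data parallel group is already initialized" in e` against the exception object itself. In Python, `in` applied to an `Exception` raises `TypeError` (exceptions are not iterable), so the intended "already initialized" escape hatch could never fire. Converting with `str(e)` first makes the substring check valid. A minimal standalone demonstration:

```python
# Standalone demo of the bug fixed in batch_completions: `in` against an
# Exception object raises TypeError, while `in` against str(e) works.
try:
    raise RuntimeError("data parallel group is already initialized")
except Exception as e:
    error_str = str(e)
    print("data parallel group is already initialized" in error_str)  # True
    try:
        "data parallel group is already initialized" in e  # old, broken check
    except TypeError as te:
        print(f"old check raised: {te}")  # ... 'RuntimeError' is not iterable
```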
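
For context on the `litellm/main.py` hunk, here is a sketch of how a caller might exercise `batch_completion` after the new keyword-only parameters land. The model id and prompts are placeholders, not part of the patch, and the call assumes the signature shown in the diff above; everything after the `*` must now be passed by keyword.

```python
import litellm

# Hypothetical usage sketch: two independent conversations dispatched in one
# batch_completion call, with the new keyword-only liteLLM params.
responses = litellm.batch_completion(
    model="vllm/facebook/opt-125m",  # placeholder vllm-served model id
    messages=[
        [{"role": "user", "content": "Good morning, how are you?"}],
        [{"role": "user", "content": "Summarize this patch in one line."}],
    ],
    api_base=None,      # new keyword-only param from this patch
    force_timeout=600,  # per the inline comment: used by text-bison only
)
for response in responses:
    print(response["choices"][0]["message"]["content"])
```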