fix linting issues

Krrish Dholakia 2023-09-06 20:43:56 -07:00
parent 01313cd1f0
commit 6b3cb18983
2 changed files with 21 additions and 6 deletions


@@ -3,7 +3,7 @@ import json
 from enum import Enum
 import requests
 import time
-from typing import Callable
+from typing import Callable, Any
 from litellm.utils import ModelResponse
 from .prompt_templates.factory import prompt_factory, custom_prompt
 llm = None
@@ -16,7 +16,7 @@ class VLLMError(Exception):
         )  # Call the base class constructor with the parameters it needs

 # check if vllm is installed
-def validate_environment(model: str, llm: any=None):
+def validate_environment(model: str, llm: Any=None):
     try:
         from vllm import LLM, SamplingParams
         if llm is None:
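
The signature change above is the actual lint fix: lowercase `any` is Python's builtin function, not a type, so linters and type checkers reject it as an annotation, while `typing.Any` is the accepted spelling. A minimal sketch of the distinction; the function name and body here are illustrative stand-ins, not code from the repo:

    from typing import Any

    # `any` is a builtin function, so `llm: any = None` is an invalid annotation.
    # `typing.Any` works when the parameter can hold an arbitrary engine object.
    def validate_environment_sketch(model: str, llm: Any = None) -> Any:
        # placeholder body; the real helper lazily imports vllm and only builds
        # an LLM engine when one was not passed in
        return llm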
@@ -37,6 +37,7 @@ def completion(
     litellm_params=None,
     logger_fn=None,
 ):
+    global llm
     try:
         llm, SamplingParams = validate_environment(model=model)
     except Exception as e:
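
The `global llm` declaration matters because the next line assigns to `llm`: without it, Python would create a function-local name and the module-level cache initialized as `llm = None` would never be updated. A rough sketch of the pattern, using a hypothetical stand-in for the real completion function:

    llm = None  # module-level cache, mirroring `llm = None` near the top of the file

    def completion_sketch(model: str):
        global llm  # rebind the module-level name instead of creating a new local
        llm = f"engine-for-{model}"  # stand-in for validate_environment(model=model)
        return llm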
@@ -62,7 +63,10 @@ def completion(
         additional_args={"complete_input_dict": sampling_params},
     )
-    outputs = llm.generate(prompt, sampling_params)
+    if llm:
+        outputs = llm.generate(prompt, sampling_params)
+    else:
+        raise VLLMError(status_code=0, message="Need to pass in a model name to initialize vllm")
     ## COMPLETION CALL
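
Wrapping the generate call in `if llm:` silences the "llm may be None" warning and turns a confusing AttributeError into an explicit VLLMError when no engine was initialized. A hedged sketch of the guard in isolation; the helper name is made up, and the exception class is a stand-in mirroring the one defined earlier in this file:

    class VLLMError(Exception):  # stand-in for the module's exception class
        def __init__(self, status_code, message):
            self.status_code = status_code
            self.message = message
            super().__init__(self.message)

    def generate_or_raise(llm, prompt, sampling_params):
        # only call .generate() when an engine actually exists; otherwise fail loudly
        if llm:
            return llm.generate(prompt, sampling_params)
        raise VLLMError(status_code=0, message="Need to pass in a model name to initialize vllm")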
@@ -128,10 +132,11 @@ def batch_completions(
     try:
         llm, SamplingParams = validate_environment(model=model, llm=llm)
     except Exception as e:
-        if "data parallel group is already initialized" in e:
+        error_str = str(e)
+        if "data parallel group is already initialized" in error_str:
             pass
         else:
-            raise VLLMError(status_code=0, message=str(e))
+            raise VLLMError(status_code=0, message=error_str)
     sampling_params = SamplingParams(**optional_params)
     prompts = []
     if model in custom_prompt_dict:
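
The `error_str = str(e)` change fixes a real bug rather than a style nit: a membership test like `"..." in e` against an exception object raises TypeError, because exceptions are not iterable, so the message has to be checked on its string form. A small self-contained illustration of the corrected pattern:

    try:
        raise RuntimeError("data parallel group is already initialized")
    except Exception as e:
        # `"..." in e` would raise TypeError; check the string form instead
        error_str = str(e)
        if "data parallel group is already initialized" in error_str:
            pass  # the engine already exists, so it is safe to reuse it
        else:
            raise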
@@ -150,7 +155,10 @@ def batch_completions(
             prompt = prompt_factory(model=model, messages=message)
             prompts.append(prompt)
-    outputs = llm.generate(prompts, sampling_params)
+    if llm:
+        outputs = llm.generate(prompts, sampling_params)
+    else:
+        raise VLLMError(status_code=0, message="Need to pass in a model name to initialize vllm")
     final_outputs = []
     for output in outputs:
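
For context on the loop that follows the batch call: vLLM's `generate()` returns one `RequestOutput` per prompt, with the generated text on the nested completion objects. A sketch of how such results are typically flattened, assuming the standard vLLM output shape rather than quoting the repo's exact loop body (the helper name is illustrative):

    def collect_texts(outputs):
        # `outputs` is the list returned by llm.generate(prompts, sampling_params);
        # each RequestOutput carries the prompt plus one or more generated completions
        return [output.outputs[0].text for output in outputs]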


@@ -844,6 +844,13 @@ def batch_completion(
     frequency_penalty=0,
     logit_bias: dict = {},
     user: str = "",
+    # Optional liteLLM function params
+    *,
+    return_async=False,
+    api_key: Optional[str] = None,
+    api_version: Optional[str] = None,
+    api_base: Optional[str] = None,
+    force_timeout=600,
     # used by text-bison only
     top_k=40,
     custom_llm_provider=None,):
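
Because the new parameters sit after a bare `*`, they are keyword-only: existing positional calls keep working, and the new arguments must be passed by name. An illustrative call; the model name, message shape, and values are placeholders, not taken from the commit:

    import litellm

    # everything after the bare `*` in the new signature must be passed by keyword
    responses = litellm.batch_completion(
        model="vllm/facebook/opt-125m",
        messages=[[{"role": "user", "content": "Hello!"}]],
        api_base=None,       # keyword-only
        force_timeout=600,   # keyword-only, matching the default in the new signature
    )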