style: fix linting errors

This commit is contained in:
Krrish Dholakia 2023-10-16 17:35:08 -07:00
parent 5ef054f105
commit 7572086231
3 changed files with 7 additions and 7 deletions

View file

@ -5,7 +5,7 @@ from enum import Enum
import requests import requests
import time import time
import litellm import litellm
from typing import Callable from typing import Callable, Dict, List, Any
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper
from typing import Optional from typing import Optional
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
@ -265,7 +265,7 @@ def completion(
content = "" content = ""
for chunk in streamed_response: for chunk in streamed_response:
content += chunk["choices"][0]["delta"]["content"] content += chunk["choices"][0]["delta"]["content"]
completion_response = [{"generated_text": content}] completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input_text, input=input_text,
@ -298,10 +298,10 @@ def completion(
) )
else: else:
if task == "conversational": if task == "conversational":
if len(completion_response["generated_text"]) > 0: if len(completion_response["generated_text"]) > 0: # type: ignore
model_response["choices"][0]["message"][ model_response["choices"][0]["message"][
"content" "content"
] = completion_response["generated_text"] ] = completion_response["generated_text"] # type: ignore
elif task == "text-generation-inference": elif task == "text-generation-inference":
if len(completion_response[0]["generated_text"]) > 0: if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"][ model_response["choices"][0]["message"][
@ -360,7 +360,7 @@ def embedding(
model_response=None, model_response=None,
encoding=None, encoding=None,
): ):
headers = validate_environment(api_key) headers = validate_environment(api_key, headers=None)
# print_verbose(f"{model}, {task}") # print_verbose(f"{model}, {task}")
embed_url = "" embed_url = ""
if "https" in model: if "https" in model:

View file

@ -130,7 +130,7 @@ def completion(
try: try:
import torch import torch
from transformers import AutoTokenizer from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM from petals import AutoDistributedModelForCausalLM # type: ignore
except: except:
raise Exception( raise Exception(
"Importing torch, transformers, petals failed\nTry pip installing petals \npip install git+https://github.com/bigscience-workshop/petals" "Importing torch, transformers, petals failed\nTry pip installing petals \npip install git+https://github.com/bigscience-workshop/petals"

View file

@ -18,7 +18,7 @@ class VLLMError(Exception):
# check if vllm is installed # check if vllm is installed
def validate_environment(model: str, llm: Any =None): def validate_environment(model: str, llm: Any =None):
try: try:
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams # type: ignore
if llm is None: if llm is None:
llm = LLM(model=model) llm = LLM(model=model)
return llm, SamplingParams return llm, SamplingParams