style: fix linting errors

parent 5ef054f105
commit 7572086231

3 changed files with 7 additions and 7 deletions
@@ -5,7 +5,7 @@ from enum import Enum
 import requests
 import time
 import litellm
-from typing import Callable
+from typing import Callable, Dict, List, Any
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper
 from typing import Optional
 from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -265,7 +265,7 @@ def completion(
             content = ""
             for chunk in streamed_response:
                 content += chunk["choices"][0]["delta"]["content"]
-            completion_response = [{"generated_text": content}]
+            completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
             ## LOGGING
             logging_obj.post_call(
                 input=input_text,
@@ -298,10 +298,10 @@ def completion(
             )
         else:
             if task == "conversational":
-                if len(completion_response["generated_text"]) > 0:
+                if len(completion_response["generated_text"]) > 0:  # type: ignore
                     model_response["choices"][0]["message"][
                         "content"
-                    ] = completion_response["generated_text"]
+                    ] = completion_response["generated_text"]  # type: ignore
             elif task == "text-generation-inference":
                 if len(completion_response[0]["generated_text"]) > 0:
                     model_response["choices"][0]["message"][
@@ -360,7 +360,7 @@ def embedding(
     model_response=None,
     encoding=None,
 ):
-    headers = validate_environment(api_key)
+    headers = validate_environment(api_key, headers=None)
     # print_verbose(f"{model}, {task}")
     embed_url = ""
     if "https" in model:
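The edits in this first file are type-checker appeasement rather than behavior changes: the streaming branch builds `completion_response` as a list of dicts, so it gets an explicit annotation; two union-typed accesses are silenced with `# type: ignore`; and the `embedding` call site now passes `headers=None`, presumably to satisfy the current `validate_environment` signature. A minimal sketch of the annotation pattern (names illustrative, not the repo's code):

    from typing import Any, Dict, List

    def assemble(chunks: List[str]) -> List[Dict[str, Any]]:
        content = ""
        for chunk in chunks:
            content += chunk
        # The annotation pins the element type to Dict[str, Any]; without
        # it, mypy infers Dict[str, str] from the literal, which then
        # clashes wherever non-string values are later assigned or appended.
        completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
        return completion_response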
@@ -130,7 +130,7 @@ def completion(
     try:
         import torch
         from transformers import AutoTokenizer
-        from petals import AutoDistributedModelForCausalLM
+        from petals import AutoDistributedModelForCausalLM  # type: ignore
     except:
         raise Exception(
             "Importing torch, transformers, petals failed\nTry pip installing petals \npip install git+https://github.com/bigscience-workshop/petals"
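The only change in the second file is the `# type: ignore` on the petals import: mypy cannot resolve optional dependencies that ship no type stubs, and the trailing comment suppresses that one error. Worth noting that the surrounding bare `except:` is itself a lint finding (flake8 E722) that this commit leaves alone; a narrower guard might look like this (a sketch, not the repo's code):

    try:
        import torch
        from transformers import AutoTokenizer
        from petals import AutoDistributedModelForCausalLM  # type: ignore
    except ImportError as e:
        # Catch only import failures and preserve the original cause.
        raise Exception(
            "Importing torch, transformers, petals failed\n"
            "Try pip installing petals\n"
            "pip install git+https://github.com/bigscience-workshop/petals"
        ) from e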
@@ -18,7 +18,7 @@ class VLLMError(Exception):
 # check if vllm is installed
 def validate_environment(model: str, llm: Any =None):
     try:
-        from vllm import LLM, SamplingParams
+        from vllm import LLM, SamplingParams  # type: ignore
         if llm is None:
             llm = LLM(model=model)
         return llm, SamplingParams
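The vllm change in the third file is the same pattern: a stub-less optional dependency silenced at the import line. If such per-line ignores accumulate, an alternative is to suppress the missing imports once in the mypy configuration; a sketch, assuming a mypy.ini at the repo root (the file and its presence here are hypothetical):

    [mypy-petals.*]
    ignore_missing_imports = True

    [mypy-vllm.*]
    ignore_missing_imports = True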