test: fix linting issues

This commit is contained in:
Krrish Dholakia 2023-11-09 16:50:43 -08:00
parent e12bff6d7f
commit b9e6989e41
4 changed files with 7 additions and 8 deletions

View file

@ -147,7 +147,7 @@ class AzureChatCompletion(BaseLLM):
if optional_params.get("stream", False): if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model) return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
else: else:
return self.acompletion(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model) return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
elif "stream" in optional_params and optional_params["stream"] == True: elif "stream" in optional_params and optional_params["stream"] == True:
response = self._client_session.post( response = self._client_session.post(
url=api_base, url=api_base,

View file

@ -222,7 +222,7 @@ class OpenAIChatCompletion(BaseLLM):
if optional_params.get("stream", False): if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model) return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
else: else:
return self.acompletion(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model) return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
elif "stream" in optional_params and optional_params["stream"] == True: elif "stream" in optional_params and optional_params["stream"] == True:
response = self._client_session.post( response = self._client_session.post(
url=api_base, url=api_base,
@ -276,11 +276,9 @@ class OpenAIChatCompletion(BaseLLM):
raise OpenAIError(status_code=500, message=traceback.format_exc()) raise OpenAIError(status_code=500, message=traceback.format_exc())
async def acompletion(self, async def acompletion(self,
logging_obj,
api_base: str, api_base: str,
data: dict, headers: dict, data: dict, headers: dict,
model_response: ModelResponse, model_response: ModelResponse):
model: str):
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post(api_base, json=data, headers=headers) as response: async with session.post(api_base, json=data, headers=headers) as response:
response_json = await response.json() response_json = await response.json()

View file

@ -77,7 +77,7 @@ openai_text_completions = OpenAITextCompletion()
azure_chat_completions = AzureChatCompletion() azure_chat_completions = AzureChatCompletion()
####### COMPLETION ENDPOINTS ################ ####### COMPLETION ENDPOINTS ################
async def acompletion(model: str, messages: List = [], *args, **kwargs): async def acompletion(*args, **kwargs):
""" """
Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly) Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
@ -117,7 +117,8 @@ async def acompletion(model: str, messages: List = [], *args, **kwargs):
- If `stream` is True, the function returns an async generator that yields completion lines. - If `stream` is True, the function returns an async generator that yields completion lines.
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
model = args[0] if len(args) > 0 else kwargs["model"]
messages = args[1] if len(args) > 1 else kwargs["messages"]
### INITIALIZE LOGGING OBJECT ### ### INITIALIZE LOGGING OBJECT ###
kwargs["litellm_call_id"] = str(uuid.uuid4()) kwargs["litellm_call_id"] = str(uuid.uuid4())
start_time = datetime.datetime.now() start_time = datetime.datetime.now()

View file

@ -199,7 +199,7 @@ def save_params_to_config(data: dict):
def load_router_config(router: Optional[litellm.Router], config_file_path: str): def load_router_config(router: Optional[litellm.Router], config_file_path: str):
config = {} config = {}
server_settings = {} server_settings: dict = {}
try: try:
if os.path.exists(config_file_path): if os.path.exists(config_file_path):
with open(config_file_path, 'r') as file: with open(config_file_path, 'r') as file: