refactor: fix linting issues

Krrish Dholakia 2023-11-11 18:52:28 -08:00
parent ae35c13015
commit 45b6f8b853
25 changed files with 223 additions and 133 deletions


@@ -2,16 +2,22 @@ from typing import Optional, Union
 import types
 import httpx
 from .base import BaseLLM
-from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object
+from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object, Usage
 from typing import Callable, Optional
 import aiohttp
 
 class OpenAIError(Exception):
-    def __init__(self, status_code, message, request: httpx.Request, response: httpx.Response):
+    def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
         self.status_code = status_code
         self.message = message
-        self.request = request
-        self.response = response
+        if request:
+            self.request = request
+        else:
+            self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+        if response:
+            self.response = response
+        else:
+            self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
             self.message
         ) # Call the base class constructor with the parameters it needs
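
The net effect of the new defaults: raise sites can omit request and response, and downstream handlers can still rely on both attributes existing. A minimal usage sketch (the status code and message are illustrative, not from this commit; assumes only httpx and the class above):

import httpx

# Raise with only a status code and message; the constructor above
# synthesizes placeholder httpx objects for the omitted arguments.
err = OpenAIError(status_code=429, message="rate limited")
assert str(err.request.url) == "https://api.openai.com/v1"  # synthesized request
assert err.response.status_code == 429                      # synthesized response
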
@@ -264,13 +270,13 @@ class OpenAIChatCompletion(BaseLLM):
         model: str
     ):
         with self._client_session.stream(
-                    url=f"{api_base}",
+                    url=f"{api_base}", # type: ignore
                     json=data,
                     headers=headers,
-                    method="POST"
+                    method="POST"
                 ) as response:
             if response.status_code != 200:
-                raise OpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
+                raise OpenAIError(status_code=response.status_code, message=response.text()) # type: ignore
 
             completion_stream = response.iter_lines()
             streamwrapper = CustomStreamWrapper(completion_stream=completion_stream, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
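
For context on the pattern above: httpx's Client.stream() defers reading the body, so large responses are never buffered whole, and the decoded body of a streamed response only becomes available after an explicit read(). A sketch of that behavior (placeholder URL and client; not litellm code):

import httpx

client = httpx.Client()
with client.stream(method="GET", url="https://example.com") as response:
    if response.status_code != 200:
        response.read()  # a streamed body must be read before .text is usable
        raise OpenAIError(status_code=response.status_code, message=response.text)
    for line in response.iter_lines():
        print(line)
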
@@ -292,7 +298,7 @@ class OpenAIChatCompletion(BaseLLM):
                     method="POST"
                 ) as response:
             if response.status_code != 200:
-                raise OpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
+                raise OpenAIError(status_code=response.status_code, message=response.text()) # type: ignore
 
             streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="openai",logging_obj=logging_obj)
             async for transformed_chunk in streamwrapper:
@@ -383,7 +389,7 @@ class OpenAITextCompletion(BaseLLM):
         try:
             ## RESPONSE OBJECT
             if response_object is None or model_response_object is None:
-                raise ValueError(message="Error in response object format")
+                raise ValueError("Error in response object format")
             choice_list=[]
             for idx, choice in enumerate(response_object["choices"]):
                 message = Message(content=choice["text"], role="assistant")
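
The ValueError change deserves a note: built-in exceptions accept only positional arguments, so the old keyword form would itself raise a TypeError and mask the intended error. A quick demonstration:

# Old form: TypeError, because ValueError() takes no keyword arguments.
#     raise ValueError(message="Error in response object format")

# New form raises the intended exception with the message attached:
try:
    raise ValueError("Error in response object format")
except ValueError as e:
    print(e)  # Error in response object format
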
@@ -406,11 +412,11 @@ class OpenAITextCompletion(BaseLLM):
             raise e
 
     def completion(self,
-               model: Optional[str]=None,
-               messages: Optional[list]=None,
-               model_response: Optional[ModelResponse]=None,
+               model_response: ModelResponse,
+               api_key: str,
+               model: str,
+               messages: list,
               print_verbose: Optional[Callable]=None,
-               api_key: Optional[str]=None,
               api_base: Optional[str]=None,
               logging_obj=None,
               acompletion: bool = False,
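
The signature change turns the four core parameters from Optional-with-None defaults into required arguments, which is what lets a type checker flag a missing model_response or api_key at the call site instead of a None blowing up at runtime. A reduced, hypothetical signature mirroring the pattern (not litellm's full parameter list):

from typing import Callable, Optional

# Required parameters first, defaulted keyword parameters after them.
def completion(
    model_response: dict,  # stand-in for litellm's ModelResponse
    api_key: str,
    model: str,
    messages: list,
    print_verbose: Optional[Callable] = None,
    api_base: Optional[str] = None,
) -> dict:
    # No `if model is None` guard is needed once the parameter is required.
    return {"model": model, "message_count": len(messages)}

completion({}, "sk-placeholder", "text-davinci-003", [{"role": "user", "content": "hi"}])
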
@@ -449,7 +455,7 @@ class OpenAITextCompletion(BaseLLM):
                 if optional_params.get("stream", False):
                     return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
                 else:
-                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model)
+                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model) # type: ignore
             elif optional_params.get("stream", False):
                 return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
             else:
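
The branch above routes a single entry point across two axes: async vs. sync, and streaming vs. non-streaming. A reduced sketch of that dispatch shape (hypothetical handler names, not litellm's):

import asyncio

def handle_sync() -> str:
    return "sync result"

def handle_sync_stream():
    yield "sync chunk"

async def handle_async() -> str:
    return "async result"

async def handle_async_stream():
    yield "async chunk"

def dispatch(acompletion: bool, stream: bool):
    # Async paths return a coroutine or async generator for the caller to drive.
    if acompletion:
        return handle_async_stream() if stream else handle_async()
    return handle_sync_stream() if stream else handle_sync()

print(dispatch(False, False))              # sync result
print(asyncio.run(dispatch(True, False)))  # async result
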
@@ -459,7 +465,7 @@ class OpenAITextCompletion(BaseLLM):
                     headers=headers,
                 )
                 if response.status_code != 200:
-                    raise OpenAIError(status_code=response.status_code, message=response.text, request=self._client_session.request, response=response)
+                    raise OpenAIError(status_code=response.status_code, message=response.text)
 
                 ## LOGGING
                 logging_obj.post_call(
@@ -521,7 +527,7 @@ class OpenAITextCompletion(BaseLLM):
                     method="POST"
                 ) as response:
             if response.status_code != 200:
-                raise OpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
+                raise OpenAIError(status_code=response.status_code, message=response.text)
 
             streamwrapper = CustomStreamWrapper(completion_stream=response.iter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
             for transformed_chunk in streamwrapper:
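
The switch from response.text() to response.text in this hunk and the one that follows matters because httpx exposes the decoded body as a property, not a method; calling it raises a TypeError. A quick demonstration (synthetic response, not litellm code):

import httpx

response = httpx.Response(status_code=400, text="Bad request")
print(response.text)  # "Bad request" -- .text is a property in httpx
# response.text()  would raise TypeError: 'str' object is not callable
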
@@ -542,7 +548,7 @@ class OpenAITextCompletion(BaseLLM):
                     method="POST"
                 ) as response:
             if response.status_code != 200:
-                raise OpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
+                raise OpenAIError(status_code=response.status_code, message=response.text)
 
             streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
             async for transformed_chunk in streamwrapper: