mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
refactor: fixing linting issues
This commit is contained in:
parent
ae35c13015
commit
45b6f8b853
25 changed files with 223 additions and 133 deletions
|
@ -2,16 +2,22 @@ from typing import Optional, Union
|
|||
import types
|
||||
import httpx
|
||||
from .base import BaseLLM
|
||||
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object
|
||||
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object, Usage
|
||||
from typing import Callable, Optional
|
||||
import aiohttp
|
||||
|
||||
class OpenAIError(Exception):
|
||||
def __init__(self, status_code, message, request: httpx.Request, response: httpx.Response):
|
||||
def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = request
|
||||
self.response = response
|
||||
if request:
|
||||
self.request = request
|
||||
else:
|
||||
self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
if response:
|
||||
self.response = response
|
||||
else:
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
@ -264,13 +270,13 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
model: str
|
||||
):
|
||||
with self._client_session.stream(
|
||||
url=f"{api_base}",
|
||||
url=f"{api_base}", # type: ignore
|
||||
json=data,
|
||||
headers=headers,
|
||||
method="POST"
|
||||
method="POST"
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
raise OpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
|
||||
raise OpenAIError(status_code=response.status_code, message=response.text()) # type: ignore
|
||||
|
||||
completion_stream = response.iter_lines()
|
||||
streamwrapper = CustomStreamWrapper(completion_stream=completion_stream, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
|
||||
|
@ -292,7 +298,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
method="POST"
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
raise OpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
|
||||
raise OpenAIError(status_code=response.status_code, message=response.text()) # type: ignore
|
||||
|
||||
streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="openai",logging_obj=logging_obj)
|
||||
async for transformed_chunk in streamwrapper:
|
||||
|
@ -383,7 +389,7 @@ class OpenAITextCompletion(BaseLLM):
|
|||
try:
|
||||
## RESPONSE OBJECT
|
||||
if response_object is None or model_response_object is None:
|
||||
raise ValueError(message="Error in response object format")
|
||||
raise ValueError("Error in response object format")
|
||||
choice_list=[]
|
||||
for idx, choice in enumerate(response_object["choices"]):
|
||||
message = Message(content=choice["text"], role="assistant")
|
||||
|
@ -406,11 +412,11 @@ class OpenAITextCompletion(BaseLLM):
|
|||
raise e
|
||||
|
||||
def completion(self,
|
||||
model: Optional[str]=None,
|
||||
messages: Optional[list]=None,
|
||||
model_response: Optional[ModelResponse]=None,
|
||||
model_response: ModelResponse,
|
||||
api_key: str,
|
||||
model: str,
|
||||
messages: list,
|
||||
print_verbose: Optional[Callable]=None,
|
||||
api_key: Optional[str]=None,
|
||||
api_base: Optional[str]=None,
|
||||
logging_obj=None,
|
||||
acompletion: bool = False,
|
||||
|
@ -449,7 +455,7 @@ class OpenAITextCompletion(BaseLLM):
|
|||
if optional_params.get("stream", False):
|
||||
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
|
||||
else:
|
||||
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model)
|
||||
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model) # type: ignore
|
||||
elif optional_params.get("stream", False):
|
||||
return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
|
||||
else:
|
||||
|
@ -459,7 +465,7 @@ class OpenAITextCompletion(BaseLLM):
|
|||
headers=headers,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise OpenAIError(status_code=response.status_code, message=response.text, request=self._client_session.request, response=response)
|
||||
raise OpenAIError(status_code=response.status_code, message=response.text)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
|
@ -521,7 +527,7 @@ class OpenAITextCompletion(BaseLLM):
|
|||
method="POST"
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
raise OpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
|
||||
raise OpenAIError(status_code=response.status_code, message=response.text)
|
||||
|
||||
streamwrapper = CustomStreamWrapper(completion_stream=response.iter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
|
||||
for transformed_chunk in streamwrapper:
|
||||
|
@ -542,7 +548,7 @@ class OpenAITextCompletion(BaseLLM):
|
|||
method="POST"
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
raise OpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
|
||||
raise OpenAIError(status_code=response.status_code, message=response.text)
|
||||
|
||||
streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
|
||||
async for transformed_chunk in streamwrapper:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue