fix(openai.py): fix linting issues

Krrish Dholakia 2023-11-16 11:01:20 -08:00
parent 6b14c8d2de
commit a23c0a2599
4 changed files with 14 additions and 27 deletions

View file

@@ -369,7 +369,7 @@ from .llms.vertex_ai import VertexAIConfig
from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.maritalk import MaritTalkConfig
-from .llms.bedrock import AmazonTitanConfig, AmazonAI21Config, AmazonAnthropicConfig, AmazonCohereConfig
+from .llms.bedrock import AmazonTitanConfig, AmazonAI21Config, AmazonAnthropicConfig, AmazonCohereConfig, AmazonLlamaConfig
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
from .llms.azure import AzureOpenAIConfig
from .main import * # type: ignore

View file

@@ -154,8 +154,8 @@ class OpenAITextCompletionConfig():
and v is not None}
class OpenAIChatCompletion(BaseLLM):
-openai_client: Optional[openai.Client] = None
-openai_aclient: Optional[openai.AsyncClient] = None
+openai_client: openai.Client
+openai_aclient: openai.AsyncClient
def __init__(self) -> None:
super().__init__()
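The annotation change above removes Optional so that later calls such as self.openai_client.chat.completions.create(**data) no longer trip the type checker on a possibly-None attribute. A minimal sketch of that typing idea, assuming the client is constructed before first use (this commit does not show where the clients are actually instantiated):

    import openai

    class Example:
        # A concrete attribute type (instead of Optional[openai.Client] = None)
        # means mypy does not demand a None check at every call site.
        openai_client: openai.Client

        def __init__(self) -> None:
            # Assumption for illustration only: build the client eagerly,
            # reading OPENAI_API_KEY from the environment.
            self.openai_client = openai.Client()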
@@ -232,13 +232,13 @@ class OpenAIChatCompletion(BaseLLM):
try:
if acompletion is True:
if optional_params.get("stream", False):
-return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+return self.async_streaming(logging_obj=logging_obj, data=data, model=model)
else:
-return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
+return self.acompletion(data=data, model_response=model_response)
elif optional_params.get("stream", False):
-return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+return self.streaming(logging_obj=logging_obj, data=data, model=model)
else:
-response = self.openai_client.chat.completions.create(**data)
+response = self.openai_client.chat.completions.create(**data) # type: ignore
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
except Exception as e:
if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
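For context, a sketch of the openai>=1.x call pattern this hunk relies on: the API key and base URL live on the client object itself, which is presumably why the per-call api_base and headers arguments are being dropped from these helpers. The key, endpoint, model, and message below are placeholders, not values from this commit:

    import openai

    # Placeholder credentials and endpoint; litellm would receive these from the caller.
    client = openai.OpenAI(
        api_key="sk-...",
        base_url="https://api.openai.com/v1",
    )

    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response.choices[0].message.content)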
@@ -266,8 +266,7 @@ class OpenAIChatCompletion(BaseLLM):
raise e
async def acompletion(self,
-api_base: str,
-data: dict, headers: dict,
+data: dict,
model_response: ModelResponse):
response = None
try:
@@ -281,10 +280,7 @@
def streaming(self,
logging_obj,
-api_base: str,
data: dict,
-headers: dict,
-model_response: ModelResponse,
model: str
):
response = self.openai_client.chat.completions.create(**data)
@@ -294,10 +290,7 @@
async def async_streaming(self,
logging_obj,
-api_base: str,
data: dict,
-headers: dict,
-model_response: ModelResponse,
model: str):
response = await self.openai_aclient.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
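A rough sketch of the async streaming call that async_streaming wraps with CustomStreamWrapper; the client construction, model, and prompt are assumptions for illustration, since the diff does not show them:

    import asyncio
    import openai

    async def stream_demo() -> None:
        client = openai.AsyncClient()  # assumes OPENAI_API_KEY is set in the environment
        stream = await client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hi"}],
            stream=True,
        )
        # Each chunk carries an incremental delta; content can be None on the
        # first and last chunks.
        async for chunk in stream:
            delta = chunk.choices[0].delta.content
            if delta is not None:
                print(delta, end="")

    asyncio.run(stream_demo())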
@@ -315,10 +308,8 @@ class OpenAIChatCompletion(BaseLLM):
optional_params=None,):
super().embedding()
exception_mapping_worked = False
-if self._client_session is None:
-self._client_session = self.create_client_session()
try:
-headers = self.validate_environment(api_key)
+headers = self.validate_environment(api_key, api_base=api_base, headers=None)
api_base = f"{api_base}/embeddings"
model = model
data = {
@@ -334,9 +325,7 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
-response = requests.post(
-api_base, headers=headers, json=data, timeout=litellm.request_timeout
-)
+response = self.openai_client.embeddings.create(**data) # type: ignore
## LOGGING
logging_obj.post_call(
input=input,
@@ -345,9 +334,7 @@ class OpenAIChatCompletion(BaseLLM):
original_response=response,
)
-if response.status_code!=200:
-raise OpenAIError(message=response.text, status_code=response.status_code)
-embedding_response = response.json()
+embedding_response = json.loads(response.model_dump_json())
output_data = []
for idx, embedding in enumerate(embedding_response["data"]):
output_data.append(
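The two embedding hunks above swap the hand-rolled requests.post call for the SDK client and read the Pydantic response through model_dump_json(). A standalone sketch of that flow, with a placeholder model and input:

    import json
    import openai

    client = openai.Client()  # assumes OPENAI_API_KEY is set in the environment
    response = client.embeddings.create(
        model="text-embedding-ada-002",
        input=["hello world"],
    )

    # The SDK returns a Pydantic object; dumping to JSON and re-loading gives the
    # plain dict shape the loop above iterates over.
    embedding_response = json.loads(response.model_dump_json())
    for idx, embedding in enumerate(embedding_response["data"]):
        print(idx, len(embedding["embedding"]))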

View file

@@ -50,7 +50,7 @@ def test_async_response():
pytest.fail(f"An exception occurred: {e}")
asyncio.run(test_get_response())
test_async_response()
def test_async_anyscale_response():
import asyncio
litellm.set_verbose = True
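For reference, a hedged sketch of the async completion pattern tests like these exercise; the hunk does not show the test bodies, so the model and messages are placeholders:

    import asyncio
    import litellm

    async def get_response() -> None:
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "ping"}],
        )
        print(response)

    asyncio.run(get_response())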

View file

@@ -20,7 +20,7 @@ def test_openai_embedding():
# print(f"response: {str(response)}")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
-# test_openai_embedding()
+test_openai_embedding()
def test_openai_azure_embedding_simple():
try:
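A hedged sketch of the embedding call a test like test_openai_embedding presumably makes; the hunk does not show the test body, so the model and input are placeholders:

    import litellm

    response = litellm.embedding(
        model="text-embedding-ada-002",
        input=["good morning from litellm"],
    )
    print(response)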