fix(openai.py): fix linting issues

Krrish Dholakia 2023-11-16 11:01:20 -08:00
parent 6b14c8d2de
commit a23c0a2599
4 changed files with 14 additions and 27 deletions

(file 1 of 4; path not shown)

@@ -369,7 +369,7 @@ from .llms.vertex_ai import VertexAIConfig
 from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.maritalk import MaritTalkConfig
-from .llms.bedrock import AmazonTitanConfig, AmazonAI21Config, AmazonAnthropicConfig, AmazonCohereConfig
+from .llms.bedrock import AmazonTitanConfig, AmazonAI21Config, AmazonAnthropicConfig, AmazonCohereConfig, AmazonLlamaConfig
 from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
 from .llms.azure import AzureOpenAIConfig
 from .main import * # type: ignore

(file 2 of 4: openai.py)

@@ -154,8 +154,8 @@ class OpenAITextCompletionConfig():
             and v is not None}

 class OpenAIChatCompletion(BaseLLM):
-    openai_client: Optional[openai.Client] = None
-    openai_aclient: Optional[openai.AsyncClient] = None
+    openai_client: openai.Client
+    openai_aclient: openai.AsyncClient

     def __init__(self) -> None:
         super().__init__()
@@ -232,13 +232,13 @@ class OpenAIChatCompletion(BaseLLM):
         try:
             if acompletion is True:
                 if optional_params.get("stream", False):
-                    return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+                    return self.async_streaming(logging_obj=logging_obj, data=data, model=model)
                 else:
-                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
+                    return self.acompletion(data=data, model_response=model_response)
             elif optional_params.get("stream", False):
-                return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+                return self.streaming(logging_obj=logging_obj, data=data, model=model)
             else:
-                response = self.openai_client.chat.completions.create(**data)
+                response = self.openai_client.chat.completions.create(**data) # type: ignore
                 return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except Exception as e:
             if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
@@ -266,8 +266,7 @@ class OpenAIChatCompletion(BaseLLM):
             raise e

     async def acompletion(self,
-                          api_base: str,
                           data: dict,
-                          headers: dict,
                           model_response: ModelResponse):
         response = None
         try:
@@ -281,10 +280,7 @@ class OpenAIChatCompletion(BaseLLM):

     def streaming(self,
                   logging_obj,
-                  api_base: str,
                   data: dict,
-                  headers: dict,
-                  model_response: ModelResponse,
                   model: str
     ):
         response = self.openai_client.chat.completions.create(**data)
@@ -294,10 +290,7 @@ class OpenAIChatCompletion(BaseLLM):

     async def async_streaming(self,
                               logging_obj,
-                              api_base: str,
                               data: dict,
-                              headers: dict,
-                              model_response: ModelResponse,
                               model: str):
         response = await self.openai_aclient.chat.completions.create(**data)
         streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
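
The streaming and async-completion signatures above shrink because the handler now keeps openai.Client / openai.AsyncClient instances on the class, so api_base and headers no longer have to be threaded through every call. A minimal sketch of the async streaming path this relies on, assuming the openai v1 SDK; the model and message are illustrative, not taken from this commit:

import asyncio
import openai

async def stream_chat():
    # AsyncClient reads OPENAI_API_KEY from the environment.
    aclient = openai.AsyncClient()
    # With stream=True the awaited call returns an async iterator of
    # chunks, which is what CustomStreamWrapper consumes above.
    response = await aclient.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hi"}],
        stream=True,
    )
    async for chunk in response:
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(stream_chat())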
@@ -315,10 +308,8 @@ class OpenAIChatCompletion(BaseLLM):
                   optional_params=None,):
         super().embedding()
         exception_mapping_worked = False
-        if self._client_session is None:
-            self._client_session = self.create_client_session()
         try:
-            headers = self.validate_environment(api_key)
+            headers = self.validate_environment(api_key, api_base=api_base, headers=None)
             api_base = f"{api_base}/embeddings"
             model = model
             data = {
@@ -334,9 +325,7 @@ class OpenAIChatCompletion(BaseLLM):
                 additional_args={"complete_input_dict": data},
             )
             ## COMPLETION CALL
-            response = requests.post(
-                api_base, headers=headers, json=data, timeout=litellm.request_timeout
-            )
+            response = self.openai_client.embeddings.create(**data) # type: ignore
             ## LOGGING
             logging_obj.post_call(
                 input=input,
@@ -344,10 +333,8 @@ class OpenAIChatCompletion(BaseLLM):
                 additional_args={"complete_input_dict": data},
                 original_response=response,
             )

-            if response.status_code!=200:
-                raise OpenAIError(message=response.text, status_code=response.status_code)
-            embedding_response = response.json()
+            embedding_response = json.loads(response.model_dump_json())
             output_data = []
             for idx, embedding in enumerate(embedding_response["data"]):
                 output_data.append(
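
The embedding path makes the same move: the raw requests.post call and manual status-code check are replaced by a call on the shared client, with the pydantic response round-tripped through JSON to get a plain dict. A hedged sketch of that pattern, assuming the openai v1 SDK; the model name and input are illustrative:

import json
import openai

client = openai.Client()  # reads OPENAI_API_KEY from the environment

response = client.embeddings.create(
    model="text-embedding-ada-002",
    input=["hello world"],
)

# Mirror of json.loads(response.model_dump_json()) above: turn the
# pydantic response model into the plain dict the rest of the
# function expects.
embedding_response = json.loads(response.model_dump_json())
for idx, embedding in enumerate(embedding_response["data"]):
    print(idx, len(embedding["embedding"]))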

(file 3 of 4; path not shown)

@@ -50,7 +50,7 @@ def test_async_response():
             pytest.fail(f"An exception occurred: {e}")

     asyncio.run(test_get_response())
+test_async_response()

 def test_async_anyscale_response():
     import asyncio
     litellm.set_verbose = True
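
For context, the test being enabled follows the asyncio.run pattern visible in the surrounding lines; reconstructed in outline here, with an illustrative model and message:

import asyncio
import pytest
from litellm import acompletion

def test_async_response():
    async def test_get_response():
        try:
            response = await acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "Hello, how are you?"}],
            )
            print(f"response: {response}")
        except Exception as e:
            pytest.fail(f"An exception occurred: {e}")

    asyncio.run(test_get_response())

test_async_response()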

(file 4 of 4; path not shown)

@@ -20,7 +20,7 @@ def test_openai_embedding():
         # print(f"response: {str(response)}")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-# test_openai_embedding()
+test_openai_embedding()

def test_openai_azure_embedding_simple():
    try:
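
Re-enabling test_openai_embedding() exercises the new client-based embedding path end to end. A sketch of what such a smoke test looks like, assuming litellm's embedding() entry point; the model and input are illustrative:

import pytest
from litellm import embedding

def test_openai_embedding():
    try:
        response = embedding(
            model="text-embedding-ada-002",
            input=["good morning from litellm"],
        )
        # print(f"response: {str(response)}")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

test_openai_embedding()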