fix(openai.py): handle extra headers

This commit is contained in:
Krrish Dholakia 2023-11-16 12:48:14 -08:00
parent 9e072f87bd
commit a94c09c13c
6 changed files with 98 additions and 118 deletions

View file

@ -5,7 +5,8 @@ from .base import BaseLLM
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object, Usage
from typing import Callable, Optional
import aiohttp, requests
import litellm, openai
import litellm
from openai import OpenAI, AsyncOpenAI
class OpenAIError(Exception):
def __init__(self, status_code, message, request: Optional[httpx.Request]=None, response: Optional[httpx.Response]=None):
@ -154,46 +155,9 @@ class OpenAITextCompletionConfig():
and v is not None}
class OpenAIChatCompletion(BaseLLM):
openai_client: openai.Client
openai_aclient: openai.AsyncClient
def __init__(self) -> None:
    """Create the shared sync and async OpenAI clients reused across calls."""
    super().__init__()
    # Clients are built once with default settings; credentials at this
    # point come from the environment (e.g. OPENAI_API_KEY) and are
    # patched per-call elsewhere.
    self.openai_client = openai.OpenAI()
    self.openai_aclient = openai.AsyncOpenAI()
def validate_environment(self, api_key, api_base, headers):
    """Build request headers and sync the shared clients with this call's credentials.

    Args:
        api_key: Bearer token. When truthy it is added to the headers as an
            ``Authorization`` header and set on both shared clients.
        api_base: Base-URL override. When truthy both shared clients are
            repointed at it (normalized to end with a trailing slash).
        headers: Existing header dict, or ``None`` to start from a default
            ``content-type`` header.

    Returns:
        The (possibly newly created) headers dict.
    """
    if headers is None:
        headers = {
            "content-type": "application/json",
        }
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
        self.openai_client.api_key = api_key
        self.openai_aclient.api_key = api_key
    if api_base:
        # openai-python resolves relative request paths against base_url,
        # which must end with "/"; normalize once instead of repeating the
        # endswith check per client.
        normalized = api_base if api_base.endswith("/") else api_base + "/"
        for client in (self.openai_client, self.openai_aclient):
            # base_url is exposed read-only on the client, hence the
            # private _base_url assignment (same trick as the original).
            # The original's extra `is None` guard was redundant: None
            # never compares equal to a truthy api_base.
            if client.base_url != api_base:
                client._base_url = httpx.URL(url=normalized)
    return headers
def _retry_request(self, *args, **kwargs):
    # Crude retry helper: spend one unit of the instance-level retry
    # budget, back off briefly, then replay the original call.
    self._num_retry_httpx_errors -= 1
    # NOTE(review): fixed 1 s blocking sleep — assumes callers tolerate the
    # delay; `time` is not imported in the visible chunk — confirm it is
    # imported at module top.
    time.sleep(1)
    # The callable to retry rides inside kwargs; pop it so it is not
    # forwarded to itself on the replay.
    original_function = kwargs.pop("original_function")
    return original_function(*args, **kwargs)
def completion(self,
model_response: ModelResponse,
@ -211,7 +175,8 @@ class OpenAIChatCompletion(BaseLLM):
super().completion()
exception_mapping_worked = False
try:
headers = self.validate_environment(api_key=api_key, api_base=api_base, headers=headers)
if headers:
optional_params["extra_headers"] = headers
if model is None or messages is None:
raise OpenAIError(status_code=422, message=f"Missing model or messages")
@ -232,13 +197,14 @@ class OpenAIChatCompletion(BaseLLM):
try:
if acompletion is True:
if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, data=data, model=model)
return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key)
else:
return self.acompletion(data=data, model_response=model_response)
return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key)
elif optional_params.get("stream", False):
return self.streaming(logging_obj=logging_obj, data=data, model=model)
return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key)
else:
response = self.openai_client.chat.completions.create(**data) # type: ignore
openai_client = OpenAI(api_key=api_key, base_url=api_base)
response = openai_client.chat.completions.create(**data) # type: ignore
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
except Exception as e:
if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
@ -267,10 +233,13 @@ class OpenAIChatCompletion(BaseLLM):
async def acompletion(self,
data: dict,
model_response: ModelResponse):
model_response: ModelResponse,
api_base: str,
api_key: str):
response = None
try:
response = await self.openai_aclient.chat.completions.create(**data)
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base)
response = await openai_aclient.chat.completions.create(**data)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
except Exception as e:
if response and hasattr(response, "text"):
@ -281,9 +250,12 @@ class OpenAIChatCompletion(BaseLLM):
def streaming(self,
              logging_obj,
              data: dict,
              model: str,
              api_key: str,
              api_base: str
              ):
    """Synchronously stream chat-completion chunks.

    A fresh ``OpenAI`` client is built per call so the caller-supplied
    ``api_key``/``api_base`` are honored (the point of this commit).
    The SOURCE span fused the pre- and post-change lines of the diff
    (two conflicting signatures, two client/response lines) into invalid
    Python; this body is the resolved post-commit version.

    Yields:
        Chunks transformed by ``CustomStreamWrapper``.
    """
    openai_client = OpenAI(api_key=api_key, base_url=api_base)
    response = openai_client.chat.completions.create(**data)
    streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai", logging_obj=logging_obj)
    for transformed_chunk in streamwrapper:
        yield transformed_chunk
@ -291,8 +263,11 @@ class OpenAIChatCompletion(BaseLLM):
async def async_streaming(self,
                          logging_obj,
                          data: dict,
                          model: str,
                          api_key: str,
                          api_base: str):
    """Asynchronously stream chat-completion chunks.

    Mirrors ``streaming`` but with a per-call ``AsyncOpenAI`` client so the
    caller-supplied ``api_key``/``api_base`` are honored. The SOURCE span
    fused the pre- and post-change diff lines (old 3-arg signature plus the
    new 5-arg one) into invalid Python; this body is the resolved
    post-commit version.

    Yields:
        Chunks transformed by ``CustomStreamWrapper``.
    """
    openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base)
    response = await openai_aclient.chat.completions.create(**data)
    streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai", logging_obj=logging_obj)
    async for transformed_chunk in streamwrapper:
        yield transformed_chunk
@ -309,8 +284,7 @@ class OpenAIChatCompletion(BaseLLM):
super().embedding()
exception_mapping_worked = False
try:
headers = self.validate_environment(api_key, api_base=api_base, headers=None)
api_base = f"{api_base}/embeddings"
openai_client = OpenAI(api_key=api_key, api_base=api_base)
model = model
data = {
"model": model,
@ -325,7 +299,7 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
response = self.openai_client.embeddings.create(**data) # type: ignore
response = openai_client.embeddings.create(**data) # type: ignore
## LOGGING
logging_obj.post_call(
input=input,