openai - return response headers

commit c07b8d9575
parent 966733ed22
Author: Ishaan Jaff
Date:   2024-07-20 15:04:27 -07:00


@@ -960,6 +960,7 @@ class OpenAIChatCompletion(BaseLLM):
             return convert_to_model_response_object(
                 response_object=stringified_response,
                 model_response_object=model_response,
+                response_headers=headers,
             )
         except Exception as e:
             if print_verbose is not None:
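
The headers value forwarded into convert_to_model_response_object() above comes from the OpenAI SDK's raw-response API. A minimal, standalone sketch of that pattern (model and prompt are placeholders, not taken from this commit):

    from openai import OpenAI

    client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

    raw = client.chat.completions.with_raw_response.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )
    headers = dict(raw.headers)  # e.g. x-request-id, x-ratelimit-remaining-requests
    response = raw.parse()       # the regular ChatCompletion object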
@@ -1108,6 +1109,7 @@ class OpenAIChatCompletion(BaseLLM):
                 custom_llm_provider="openai",
                 logging_obj=logging_obj,
                 stream_options=data.get("stream_options", None),
+                response_headers=headers,
             )
             return streamwrapper
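
In the streaming path there is no single response object to attach headers to, so they are handed to the stream wrapper instead. An illustrative wrapper (not litellm's actual streamwrapper class) that carries the headers next to the chunk iterator:

    from typing import Dict, Iterator, Optional

    class StreamWithHeaders:
        # Keeps the HTTP response headers available while the chunks are iterated.
        def __init__(self, stream: Iterator, response_headers: Optional[Dict] = None):
            self._stream = stream
            self.response_headers = response_headers or {}

        def __iter__(self):
            return iter(self._stream)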
@ -1201,7 +1203,7 @@ class OpenAIChatCompletion(BaseLLM):
except Exception as e: except Exception as e:
raise e raise e
async def make_sync_openai_embedding_request( def make_sync_openai_embedding_request(
self, self,
openai_client: OpenAI, openai_client: OpenAI,
data: dict, data: dict,
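
This hunk drops a stray async from a helper that is called synchronously: an async def invoked without await returns a coroutine object, so unpacking its result into (headers, response) would fail at runtime. A small sketch of that failure mode (names here are illustrative):

    import asyncio

    async def fetch_pair():
        return {"x-request-id": "abc"}, "parsed-response"

    result = fetch_pair()  # a coroutine object, not a (headers, response) tuple
    # headers, response = result  # TypeError: cannot unpack non-iterable coroutine object
    result.close()  # discard the never-awaited coroutine

    headers, response = asyncio.run(fetch_pair())  # awaiting it properly does work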
@@ -1217,6 +1219,7 @@ class OpenAIChatCompletion(BaseLLM):
             raw_response = openai_client.embeddings.with_raw_response.create(
                 **data, timeout=timeout
             )  # type: ignore
+
             headers = dict(raw_response.headers)
             response = raw_response.parse()
             return headers, response
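
The helper above pairs the parsed embedding response with the headers of the underlying HTTP response. A standalone sketch of the same raw-response pattern (the function name is illustrative and mirrors, rather than reproduces, make_sync_openai_embedding_request):

    from typing import Dict, Tuple

    from openai import OpenAI

    def embedding_with_headers(client: OpenAI, **data) -> Tuple[Dict, object]:
        # with_raw_response exposes the HTTP layer; parse() returns the usual embeddings response
        raw = client.embeddings.with_raw_response.create(**data)
        return dict(raw.headers), raw.parse()

    # usage (model and input are placeholders):
    # headers, resp = embedding_with_headers(OpenAI(), model="text-embedding-3-small", input=["hi"])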
@@ -1321,9 +1324,9 @@ class OpenAIChatCompletion(BaseLLM):
                 client=client,
             )
-            ## COMPLETION CALL
+            ## embedding CALL
             headers: Optional[Dict] = None
-            headers, response = self.make_sync_openai_embedding_request(
+            headers, sync_embedding_response = self.make_sync_openai_embedding_request(
                 openai_client=openai_client, data=data, timeout=timeout
             )  # type: ignore
@@ -1333,9 +1336,14 @@ class OpenAIChatCompletion(BaseLLM):
                 input=input,
                 api_key=api_key,
                 additional_args={"complete_input_dict": data},
-                original_response=response,
+                original_response=sync_embedding_response,
             )
-            return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding")  # type: ignore
+            return convert_to_model_response_object(
+                response_object=sync_embedding_response.model_dump(),
+                model_response_object=model_response,
+                response_headers=headers,
+                response_type="embedding",
+            )  # type: ignore
         except OpenAIError as e:
             exception_mapping_worked = True
             raise e
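
The embedding path now threads headers through convert_to_model_response_object via the new response_headers parameter; how the converter stores them is not shown in this commit. A minimal stand-in that attaches the headers to the response object, assuming an attribute name such as _response_headers (that name is an assumption for illustration, not taken from this diff):

    from typing import Dict, Optional

    def attach_response_headers(model_response, response_headers: Optional[Dict] = None):
        # Stand-in for the converter's response_headers handling; the real function does much more.
        if response_headers is not None:
            setattr(model_response, "_response_headers", response_headers)  # assumed attribute name
        return model_response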