Patch summary (free text ahead of the first "diff --git" header):

* litellm/litellm_core_utils/litellm_logging.py
  - imports ResponseCompletedEvent and accepts it in the result Union and
    in the isinstance chain that guards the hidden-params handling;
  - drops the "self.stream is not True" clause from that guard.
* litellm/responses/streaming_iterator.py
  - compares the chunk type against
    ResponsesAPIStreamEvents.RESPONSE_COMPLETED;
  - schedules async_success_handler with asyncio.create_task instead of
    awaiting it;
  - the fallback return after the completed-chunk branch now recurses
    into __anext__() instead of wrapping parsed_chunk.
* litellm/types/llms/openai.py
  - adds BaseResponseAPIStreamEvent (a BaseModel with __getitem__, get and
    __contains__ backed by __dict__) and rebases the stream event models
    on it;
  - adds a private _hidden_params attribute to ResponseCompletedEvent.

NOTE(review): the context comment "# handle streaming separately" stays on
the condition whose stream check is removed; confirm it is still accurate.
NOTE(review): the task returned by asyncio.create_task is not stored; the
asyncio docs advise keeping a reference. No "import asyncio" is added
here; presumably the module already imports it, verify.
NOTE(review): indentation and blank context lines were reconstructed from
a whitespace-collapsed copy; line counts match every hunk header, but
confirm the leading whitespace against the target files before applying.

diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 8a07ac9a8e..964bbfb70c 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -45,6 +45,7 @@ from litellm.types.llms.openai import (
     Batch,
     FineTuningJob,
     HttpxBinaryResponseContent,
+    ResponseCompletedEvent,
     ResponsesAPIResponse,
 )
 from litellm.types.rerank import RerankResponse
@@ -854,6 +855,7 @@ class Logging(LiteLLMLoggingBaseClass):
             Batch,
             FineTuningJob,
             ResponsesAPIResponse,
+            ResponseCompletedEvent,
         ],
         cache_hit: Optional[bool] = None,
     ) -> Optional[float]:
@@ -1000,9 +1002,7 @@ class Logging(LiteLLMLoggingBaseClass):
             ## if model in model cost map - log the response cost
             ## else set cost to None
             if (
-                standard_logging_object is None
-                and result is not None
-                and self.stream is not True
+                standard_logging_object is None and result is not None
             ):  # handle streaming separately
                 if (
                     isinstance(result, ModelResponse)
@@ -1016,6 +1016,7 @@ class Logging(LiteLLMLoggingBaseClass):
                     or isinstance(result, FineTuningJob)
                     or isinstance(result, LiteLLMBatch)
                     or isinstance(result, ResponsesAPIResponse)
+                    or isinstance(result, ResponseCompletedEvent)
                 ):
                     ## HIDDEN PARAMS ##
                     hidden_params = getattr(result, "_hidden_params", {})
diff --git a/litellm/responses/streaming_iterator.py b/litellm/responses/streaming_iterator.py
index b99d81309a..7464d04607 100644
--- a/litellm/responses/streaming_iterator.py
+++ b/litellm/responses/streaming_iterator.py
@@ -9,6 +9,7 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
 from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
 from litellm.types.llms.openai import (
     ResponsesAPIResponse,
+    ResponsesAPIStreamEvents,
     ResponsesAPIStreamingResponse,
 )
 from litellm.utils import CustomStreamWrapper
@@ -82,21 +83,20 @@ class ResponsesAPIStreamingIterator:
                     if (
                         openai_responses_api_chunk
                         and openai_responses_api_chunk.type
-                        == COMPLETED_OPENAI_CHUNK_TYPE
+                        == ResponsesAPIStreamEvents.RESPONSE_COMPLETED
                     ):
                         self.completed_response = openai_responses_api_chunk
-                        await self.logging_obj.async_success_handler(
-                            result=self.completed_response,
-                            start_time=self.start_time,
-                            end_time=datetime.now(),
-                            cache_hit=None,
+                        asyncio.create_task(
+                            self.logging_obj.async_success_handler(
+                                result=self.completed_response,
+                                start_time=self.start_time,
+                                end_time=datetime.now(),
+                                cache_hit=None,
+                            )
                         )
                     return openai_responses_api_chunk
 
-                return ResponsesAPIStreamingResponse(
-                    type="response", response=parsed_chunk
-                )
-
+                return await self.__anext__()
             except json.JSONDecodeError:
                 # If we can't parse the chunk, continue to the next one
                 return await self.__anext__()
diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py
index 5d6adc263e..9e55ac30e9 100644
--- a/litellm/types/llms/openai.py
+++ b/litellm/types/llms/openai.py
@@ -852,44 +852,56 @@ class ResponsesAPIStreamEvents(str, Enum):
 
 
 # Base streaming response types
-class ResponseCreatedEvent(BaseModel):
+class BaseResponseAPIStreamEvent(BaseModel):
+    def __getitem__(self, key):
+        return self.__dict__[key]
+
+    def get(self, key, default=None):
+        return self.__dict__.get(key, default)
+
+    def __contains__(self, key):
+        return key in self.__dict__
+
+
+class ResponseCreatedEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.RESPONSE_CREATED]
     response: ResponsesAPIResponse
 
 
-class ResponseInProgressEvent(BaseModel):
+class ResponseInProgressEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.RESPONSE_IN_PROGRESS]
     response: ResponsesAPIResponse
 
 
-class ResponseCompletedEvent(BaseModel):
+class ResponseCompletedEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.RESPONSE_COMPLETED]
     response: ResponsesAPIResponse
+    _hidden_params: dict = PrivateAttr(default_factory=dict)
 
 
-class ResponseFailedEvent(BaseModel):
+class ResponseFailedEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.RESPONSE_FAILED]
     response: ResponsesAPIResponse
 
 
-class ResponseIncompleteEvent(BaseModel):
+class ResponseIncompleteEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.RESPONSE_INCOMPLETE]
     response: ResponsesAPIResponse
 
 
-class OutputItemAddedEvent(BaseModel):
+class OutputItemAddedEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED]
     output_index: int
     item: dict
 
 
-class OutputItemDoneEvent(BaseModel):
+class OutputItemDoneEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE]
     output_index: int
     item: dict
 
 
-class ContentPartAddedEvent(BaseModel):
+class ContentPartAddedEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.CONTENT_PART_ADDED]
     item_id: str
     output_index: int
@@ -897,7 +909,7 @@ class ContentPartAddedEvent(BaseModel):
     part: dict
 
 
-class ContentPartDoneEvent(BaseModel):
+class ContentPartDoneEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.CONTENT_PART_DONE]
     item_id: str
     output_index: int
@@ -905,7 +917,7 @@ class ContentPartDoneEvent(BaseModel):
     part: dict
 
 
-class OutputTextDeltaEvent(BaseModel):
+class OutputTextDeltaEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.OUTPUT_TEXT_DELTA]
     item_id: str
     output_index: int
@@ -913,7 +925,7 @@ class OutputTextDeltaEvent(BaseModel):
     delta: str
 
 
-class OutputTextAnnotationAddedEvent(BaseModel):
+class OutputTextAnnotationAddedEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.OUTPUT_TEXT_ANNOTATION_ADDED]
     item_id: str
     output_index: int
@@ -922,7 +934,7 @@ class OutputTextAnnotationAddedEvent(BaseModel):
     annotation: dict
 
 
-class OutputTextDoneEvent(BaseModel):
+class OutputTextDoneEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.OUTPUT_TEXT_DONE]
     item_id: str
     output_index: int
@@ -930,7 +942,7 @@ class OutputTextDoneEvent(BaseModel):
     text: str
 
 
-class RefusalDeltaEvent(BaseModel):
+class RefusalDeltaEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.REFUSAL_DELTA]
     item_id: str
     output_index: int
@@ -938,7 +950,7 @@ class RefusalDeltaEvent(BaseModel):
     delta: str
 
 
-class RefusalDoneEvent(BaseModel):
+class RefusalDoneEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.REFUSAL_DONE]
     item_id: str
     output_index: int
@@ -946,57 +958,57 @@ class RefusalDoneEvent(BaseModel):
     refusal: str
 
 
-class FunctionCallArgumentsDeltaEvent(BaseModel):
+class FunctionCallArgumentsDeltaEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DELTA]
     item_id: str
     output_index: int
     delta: str
 
 
-class FunctionCallArgumentsDoneEvent(BaseModel):
+class FunctionCallArgumentsDoneEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DONE]
     item_id: str
     output_index: int
     arguments: str
 
 
-class FileSearchCallInProgressEvent(BaseModel):
+class FileSearchCallInProgressEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.FILE_SEARCH_CALL_IN_PROGRESS]
     output_index: int
     item_id: str
 
 
-class FileSearchCallSearchingEvent(BaseModel):
+class FileSearchCallSearchingEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.FILE_SEARCH_CALL_SEARCHING]
     output_index: int
     item_id: str
 
 
-class FileSearchCallCompletedEvent(BaseModel):
+class FileSearchCallCompletedEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.FILE_SEARCH_CALL_COMPLETED]
     output_index: int
     item_id: str
 
 
-class WebSearchCallInProgressEvent(BaseModel):
+class WebSearchCallInProgressEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.WEB_SEARCH_CALL_IN_PROGRESS]
     output_index: int
     item_id: str
 
 
-class WebSearchCallSearchingEvent(BaseModel):
+class WebSearchCallSearchingEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.WEB_SEARCH_CALL_SEARCHING]
     output_index: int
     item_id: str
 
 
-class WebSearchCallCompletedEvent(BaseModel):
+class WebSearchCallCompletedEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.WEB_SEARCH_CALL_COMPLETED]
     output_index: int
     item_id: str
 
 
-class ErrorEvent(BaseModel):
+class ErrorEvent(BaseResponseAPIStreamEvent):
     type: Literal[ResponsesAPIStreamEvents.ERROR]
     code: Optional[str]
     message: str