pass through track usage for streaming endpoints

Author: Ishaan Jaff
Date: 2024-09-02 16:11:20 -07:00
parent ef6b90a657
commit 3f9c58507e
2 changed files with 26 additions and 7 deletions

File 1 of 2 (pass-through request handler):

```diff
@@ -426,6 +426,8 @@ async def pass_through_request(
                 litellm_logging_obj=logging_obj,
                 iterator_type=ModelIteratorType.VERTEX_AI,
                 start_time=start_time,
+                passthrough_success_handler_obj=pass_through_endpoint_logging,
+                url_route=str(url),
             ):
                 yield chunk
```
```diff
@@ -468,6 +470,8 @@ async def pass_through_request(
                 litellm_logging_obj=logging_obj,
                 iterator_type=ModelIteratorType.VERTEX_AI,
                 start_time=start_time,
+                passthrough_success_handler_obj=pass_through_endpoint_logging,
+                url_route=str(url),
             ):
                 yield chunk
```
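For context, a minimal sketch (not the actual call site) of how the two new keyword arguments thread from `pass_through_request` into the chunk processor. The wrapper function, the `upstream_response` name, and the use of FastAPI's `StreamingResponse` are illustrative assumptions; only the keyword arguments shown in the hunks above come from this commit:

```python
from fastapi.responses import StreamingResponse

def build_streaming_response(upstream_response, logging_obj, url, start_time):
    # Hypothetical wrapper: re-yields bytes from chunk_processor so usage
    # can be tracked per-chunk while the client keeps streaming.
    async def stream_usage_tracked():
        async for chunk in chunk_processor(
            upstream_response.aiter_bytes(),  # assumed source of raw bytes
            litellm_logging_obj=logging_obj,
            iterator_type=ModelIteratorType.VERTEX_AI,
            start_time=start_time,
            # the two arguments this commit adds:
            passthrough_success_handler_obj=pass_through_endpoint_logging,
            url_route=str(url),
        ):
            yield chunk

    return StreamingResponse(stream_usage_tracked(), media_type="text/event-stream")
```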

File 2 of 2 (streaming chunk processor):

```diff
@@ -11,6 +11,8 @@ from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_stu
 )
 from litellm.types.utils import GenericStreamingChunk
+
+from .success_handler import PassThroughEndpointLogging
 
 
 class ModelIteratorType(Enum):
     VERTEX_AI = "vertexAI"
```
```diff
@@ -28,6 +30,7 @@ def get_litellm_chunk(
     custom_stream_wrapper: litellm.utils.CustomStreamWrapper,
     chunk_dict: Dict,
 ) -> Optional[Dict]:
     generic_chunk: GenericStreamingChunk = model_iterator.chunk_parser(chunk_dict)
     if generic_chunk:
         return custom_stream_wrapper.chunk_creator(chunk=generic_chunk)
```
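For orientation, a hedged sketch of the accumulation loop that `get_litellm_chunk` feeds. The function and variable names in the loop (`collect_chunks`, `response_iterator`) and its structure are assumptions for illustration, not lines from `chunk_processor`:

```python
import json

async def collect_chunks(response_iterator, model_iterator, custom_stream_wrapper):
    # Assumed shape of the accumulation inside chunk_processor: each raw
    # provider chunk is parsed into a GenericStreamingChunk by the
    # provider-specific iterator, re-wrapped as an OpenAI-style chunk, and
    # collected so litellm.stream_chunk_builder(chunks=all_chunks) can
    # rebuild one complete ModelResponse once the stream ends.
    all_chunks = []
    async for raw_chunk in response_iterator:
        try:
            chunk_dict = json.loads(raw_chunk)
        except json.JSONDecodeError:
            continue  # the real code tolerates non-JSON/partial lines
        litellm_chunk = get_litellm_chunk(model_iterator, custom_stream_wrapper, chunk_dict)
        if litellm_chunk is not None:
            all_chunks.append(litellm_chunk)
    return all_chunks
```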
```diff
@@ -39,6 +42,8 @@ async def chunk_processor(
     litellm_logging_obj: LiteLLMLoggingObj,
     iterator_type: ModelIteratorType,
     start_time: datetime,
+    passthrough_success_handler_obj: PassThroughEndpointLogging,
+    url_route: str,
 ) -> AsyncIterable[bytes]:
     IteratorClass = MODEL_ITERATORS[iterator_type]
```
```diff
@@ -84,11 +89,21 @@ async def chunk_processor(
         except json.JSONDecodeError:
             pass
 
-    complete_streaming_response = litellm.stream_chunk_builder(chunks=all_chunks)
+    complete_streaming_response: litellm.ModelResponse = litellm.stream_chunk_builder(
+        chunks=all_chunks
+    )
     end_time = datetime.now()
-    await litellm_logging_obj.async_success_handler(
+
+    if passthrough_success_handler_obj.is_vertex_route(url_route):
+        _model = passthrough_success_handler_obj.extract_model_from_url(url_route)
+        complete_streaming_response.model = _model
+        litellm_logging_obj.model = _model
+        litellm_logging_obj.model_call_details["model"] = _model
+
+    asyncio.create_task(
+        litellm_logging_obj.async_success_handler(
             result=complete_streaming_response,
             start_time=start_time,
             end_time=end_time,
         )
+    )
```
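The two helpers used above come from `PassThroughEndpointLogging` in `success_handler.py`; their bodies are not part of this diff. A hedged sketch of what they plausibly do for a Vertex AI route, assuming the standard publisher-model URL shape:

```python
import re

# Illustrative approximations, not the real success_handler implementations.
def is_vertex_route(url_route: str) -> bool:
    # Vertex AI pass-through traffic goes to *-aiplatform.googleapis.com.
    return "aiplatform.googleapis.com" in url_route

def extract_model_from_url(url_route: str) -> str:
    # e.g. .../publishers/google/models/gemini-1.5-pro:streamGenerateContent
    match = re.search(r"/models/([^:/]+)", url_route)
    return match.group(1) if match else "unknown"
```

Note also the switch from awaiting the success handler to `asyncio.create_task(...)`: usage logging now runs fire-and-forget after the final chunk, so the client's stream is not held open while logging callbacks finish.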