pass through track usage for streaming endpoints

Ishaan Jaff 2024-09-02 16:11:20 -07:00
parent 73d0a78444
commit a6d4a27207
2 changed files with 26 additions and 7 deletions

@@ -11,6 +11,8 @@ from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_stu
 )
 from litellm.types.utils import GenericStreamingChunk
+from .success_handler import PassThroughEndpointLogging
+
 class ModelIteratorType(Enum):
     VERTEX_AI = "vertexAI"
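The new import pulls the pass-through success handler into the streaming module. The diff only shows the handler being called, so for orientation here is a minimal sketch of the interface this commit relies on; the method names come from the hunks below, but the bodies are assumptions, not the actual success_handler.py code:

    import re

    class PassThroughEndpointLogging:
        def is_vertex_route(self, url_route: str) -> bool:
            # Assumption: Vertex AI pass-through routes carry the
            # publishers/<publisher>/models/<model> path segment.
            return "/publishers/" in url_route and "/models/" in url_route

        def extract_model_from_url(self, url_route: str) -> str:
            # Assumption: pull the model id out of a URL such as
            # ".../publishers/google/models/gemini-1.0-pro:streamGenerateContent"
            match = re.search(r"/models/([^:/]+)", url_route)
            return match.group(1) if match else url_route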
@@ -28,6 +30,7 @@ def get_litellm_chunk(
     custom_stream_wrapper: litellm.utils.CustomStreamWrapper,
     chunk_dict: Dict,
 ) -> Optional[Dict]:
     generic_chunk: GenericStreamingChunk = model_iterator.chunk_parser(chunk_dict)
     if generic_chunk:
         return custom_stream_wrapper.chunk_creator(chunk=generic_chunk)
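For context, get_litellm_chunk is the translation step: the provider-specific iterator's chunk_parser normalizes one raw chunk dict into a GenericStreamingChunk, and CustomStreamWrapper.chunk_creator turns that into an OpenAI-style chunk. Roughly, the intermediate value is shaped like this (field names follow litellm.types.utils.GenericStreamingChunk; the sample values are invented):

    from litellm.types.utils import GenericStreamingChunk

    # One normalized streaming delta, as a provider iterator might emit it.
    generic_chunk: GenericStreamingChunk = {
        "text": "Hello",        # text delta decoded from the raw provider chunk
        "is_finished": False,   # flips to True on the provider's final chunk
        "finish_reason": "",    # e.g. "stop" once is_finished is True
        "usage": None,          # token counts, usually only on the last chunk
        "index": 0,             # choice index
        "tool_use": None,       # set when the delta is a tool/function call
    }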
@@ -39,6 +42,8 @@ async def chunk_processor(
     litellm_logging_obj: LiteLLMLoggingObj,
     iterator_type: ModelIteratorType,
     start_time: datetime,
+    passthrough_success_handler_obj: PassThroughEndpointLogging,
+    url_route: str,
 ) -> AsyncIterable[bytes]:
     IteratorClass = MODEL_ITERATORS[iterator_type]
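Every caller of chunk_processor now has to thread through the success handler and the original route. A hypothetical call site in the pass-through endpoint (the surrounding names — upstream_response, logging_obj, pass_through_endpoint_logging — are assumptions for illustration; only the two new keyword arguments are taken from this diff):

    from fastapi import Request
    from fastapi.responses import StreamingResponse

    async def vertex_pass_through(request: Request) -> StreamingResponse:
        # Illustrative only: stream the proxied bytes back to the client while
        # chunk_processor accumulates chunks for usage tracking.
        return StreamingResponse(
            chunk_processor(
                upstream_response,  # assumed: the provider's streaming response
                litellm_logging_obj=logging_obj,
                iterator_type=ModelIteratorType.VERTEX_AI,
                start_time=start_time,
                passthrough_success_handler_obj=pass_through_endpoint_logging,
                url_route=str(request.url.path),
            ),
            media_type="text/event-stream",
        )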
@@ -84,11 +89,21 @@ async def chunk_processor(
         except json.JSONDecodeError:
             pass

-    complete_streaming_response = litellm.stream_chunk_builder(chunks=all_chunks)
-    end_time = datetime.now()
-    await litellm_logging_obj.async_success_handler(
-        result=complete_streaming_response,
-        start_time=start_time,
-        end_time=end_time,
-    )
+    complete_streaming_response: litellm.ModelResponse = litellm.stream_chunk_builder(
+        chunks=all_chunks
+    )
+    end_time = datetime.now()
+
+    if passthrough_success_handler_obj.is_vertex_route(url_route):
+        _model = passthrough_success_handler_obj.extract_model_from_url(url_route)
+        complete_streaming_response.model = _model
+        litellm_logging_obj.model = _model
+        litellm_logging_obj.model_call_details["model"] = _model
+    asyncio.create_task(
+        litellm_logging_obj.async_success_handler(
+            result=complete_streaming_response,
+            start_time=start_time,
+            end_time=end_time,
+        )
+    )
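Two things about the new ending are worth calling out. Setting the model on both the rebuilt response and the logging object attributes the usage record to the Vertex model extracted from the URL rather than leaving it blank. And swapping await for asyncio.create_task makes the success handler fire-and-forget, so final usage accounting no longer delays closing the stream to the client. One general caveat with bare create_task (not addressed in this diff) is that the event loop holds only a weak reference to the task; the asyncio docs recommend keeping a strong reference, e.g.:

    import asyncio

    _background_tasks: set = set()

    def fire_and_forget(coro) -> None:
        # Schedule the coroutine without awaiting it, but keep a strong
        # reference so the task can't be garbage-collected mid-flight.
        task = asyncio.create_task(coro)
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)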