From 216e7eb4d5f47290a94f3b97fd3eac7439aab4dc Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Tue, 8 Oct 2024 16:53:05 -0700
Subject: [PATCH] Move `async with SEMAPHORE` inside the async methods

---
 .../meta_reference/inference/inference.py     | 269 +++++++++---------
 1 file changed, 138 insertions(+), 131 deletions(-)

diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index 9e31f0834..43a131647 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -85,112 +85,36 @@ class MetaReferenceInferenceImpl(Inference):
         if SEMAPHORE.locked():
             raise RuntimeError("Only one concurrent request is supported")
 
-        async with SEMAPHORE:
-            if request.stream:
-                return self._stream_chat_completion(request)
-            else:
-                return self._nonstream_chat_completion(request)
+        if request.stream:
+            return self._stream_chat_completion(request)
+        else:
+            return self._nonstream_chat_completion(request)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        messages = chat_completion_request_to_messages(request)
+        async with SEMAPHORE:
+            messages = chat_completion_request_to_messages(request)
 
-        tokens = []
-        logprobs = []
-        stop_reason = None
+            tokens = []
+            logprobs = []
+            stop_reason = None
 
-        for token_result in self.generator.chat_completion(
-            messages=messages,
-            temperature=request.sampling_params.temperature,
-            top_p=request.sampling_params.top_p,
-            max_gen_len=request.sampling_params.max_tokens,
-            logprobs=request.logprobs,
-            tool_prompt_format=request.tool_prompt_format,
-        ):
-            tokens.append(token_result.token)
+            for token_result in self.generator.chat_completion(
+                messages=messages,
+                temperature=request.sampling_params.temperature,
+                top_p=request.sampling_params.top_p,
+                max_gen_len=request.sampling_params.max_tokens,
+                logprobs=request.logprobs,
+                tool_prompt_format=request.tool_prompt_format,
+            ):
+                tokens.append(token_result.token)
 
-            if token_result.text == "<|eot_id|>":
-                stop_reason = StopReason.end_of_turn
-            elif token_result.text == "<|eom_id|>":
-                stop_reason = StopReason.end_of_message
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
 
-            if request.logprobs:
-                assert len(token_result.logprobs) == 1
-
-                logprobs.append(
-                    TokenLogProbs(
-                        logprobs_by_token={token_result.text: token_result.logprobs[0]}
-                    )
-                )
-
-        if stop_reason is None:
-            stop_reason = StopReason.out_of_tokens
-
-        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
-        return ChatCompletionResponse(
-            completion_message=message,
-            logprobs=logprobs if request.logprobs else None,
-        )
-
-    async def _stream_chat_completion(
-        self, request: ChatCompletionRequest
-    ) -> AsyncGenerator:
-        messages = chat_completion_request_to_messages(request)
-
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=ChatCompletionResponseEventType.start,
-                delta="",
-            )
-        )
-
-        tokens = []
-        logprobs = []
-        stop_reason = None
-        ipython = False
-
-        for token_result in self.generator.chat_completion(
-            messages=messages,
-            temperature=request.sampling_params.temperature,
-            top_p=request.sampling_params.top_p,
-            max_gen_len=request.sampling_params.max_tokens,
-            logprobs=request.logprobs,
-            tool_prompt_format=request.tool_prompt_format,
-        ):
-            tokens.append(token_result.token)
-
-            if not ipython and token_result.text.startswith("<|python_tag|>"):
-                ipython = True
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content="",
-                            parse_status=ToolCallParseStatus.started,
-                        ),
-                    )
-                )
-                continue
-
-            if token_result.text == "<|eot_id|>":
-                stop_reason = StopReason.end_of_turn
-                text = ""
-            elif token_result.text == "<|eom_id|>":
-                stop_reason = StopReason.end_of_message
-                text = ""
-            else:
-                text = token_result.text
-
-            if ipython:
-                delta = ToolCallDelta(
-                    content=text,
-                    parse_status=ToolCallParseStatus.in_progress,
-                )
-            else:
-                delta = text
-
-            if stop_reason is None:
                 if request.logprobs:
                     assert len(token_result.logprobs) == 1
 
@@ -201,49 +125,132 @@ class MetaReferenceInferenceImpl(Inference):
                             }
                         )
                     )
+
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
+
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
+            return ChatCompletionResponse(
+                completion_message=message,
+                logprobs=logprobs if request.logprobs else None,
+            )
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        async with SEMAPHORE:
+            messages = chat_completion_request_to_messages(request)
+
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.start,
+                    delta="",
+                )
+            )
+
+            tokens = []
+            logprobs = []
+            stop_reason = None
+            ipython = False
+
+            for token_result in self.generator.chat_completion(
+                messages=messages,
+                temperature=request.sampling_params.temperature,
+                top_p=request.sampling_params.top_p,
+                max_gen_len=request.sampling_params.max_tokens,
+                logprobs=request.logprobs,
+                tool_prompt_format=request.tool_prompt_format,
+            ):
+                tokens.append(token_result.token)
+
+                if not ipython and token_result.text.startswith("<|python_tag|>"):
+                    ipython = True
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                content="",
+                                parse_status=ToolCallParseStatus.started,
+                            ),
+                        )
+                    )
+                    continue
+
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                    text = ""
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+                    text = ""
+                else:
+                    text = token_result.text
+
+                if ipython:
+                    delta = ToolCallDelta(
+                        content=text,
+                        parse_status=ToolCallParseStatus.in_progress,
+                    )
+                else:
+                    delta = text
+
+                if stop_reason is None:
+                    if request.logprobs:
+                        assert len(token_result.logprobs) == 1
+
+                        logprobs.append(
+                            TokenLogProbs(
+                                logprobs_by_token={
+                                    token_result.text: token_result.logprobs[0]
+                                }
+                            )
+                        )
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=delta,
+                            stop_reason=stop_reason,
+                            logprobs=logprobs if request.logprobs else None,
+                        )
+                    )
+
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
+
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
+
+            parsed_tool_calls = len(message.tool_calls) > 0
+            if ipython and not parsed_tool_calls:
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
-                        delta=delta,
+                        delta=ToolCallDelta(
+                            content="",
+                            parse_status=ToolCallParseStatus.failure,
+                        ),
                         stop_reason=stop_reason,
-                        logprobs=logprobs if request.logprobs else None,
                     )
                 )
 
-        if stop_reason is None:
-            stop_reason = StopReason.out_of_tokens
+            for tool_call in message.tool_calls:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            content=tool_call,
+                            parse_status=ToolCallParseStatus.success,
+                        ),
+                        stop_reason=stop_reason,
+                    )
+                )
 
-        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
-
-        parsed_tool_calls = len(message.tool_calls) > 0
-        if ipython and not parsed_tool_calls:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.progress,
-                    delta=ToolCallDelta(
-                        content="",
-                        parse_status=ToolCallParseStatus.failure,
-                    ),
+                    event_type=ChatCompletionResponseEventType.complete,
+                    delta="",
                     stop_reason=stop_reason,
                 )
             )
-
-        for tool_call in message.tool_calls:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.progress,
-                    delta=ToolCallDelta(
-                        content=tool_call,
-                        parse_status=ToolCallParseStatus.success,
-                    ),
-                    stop_reason=stop_reason,
-                )
-            )
-
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=ChatCompletionResponseEventType.complete,
-                delta="",
-                stop_reason=stop_reason,
-            )
-        )
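
For context on why the semaphore has to move: calling an async generator function such as `_stream_chat_completion` (or a coroutine function such as `_nonstream_chat_completion` without awaiting it) runs none of its body, so the old `async with SEMAPHORE:` block in `chat_completion` released the lock as soon as the generator or coroutine object was returned, before any generation happened. Acquiring the semaphore inside the methods keeps it held for the full request. The sketch below is not part of the patch; it is a minimal standalone illustration of the pattern, and the helper names (`generate`, `broken_stream`, `fixed_stream`, `consume`) are hypothetical rather than from the codebase.

import asyncio

SEMAPHORE = asyncio.Semaphore(1)


async def generate(i: int):
    # Stand-in for the real token generator.
    for tok in ("a", "b", "c"):
        await asyncio.sleep(0.01)
        yield f"req{i}:{tok}"


async def broken_stream(i: int):
    # Old shape: the semaphore only guards *creating* the generator.
    # Calling an async generator function runs none of its body, so the
    # lock is released before a single token is produced.
    async with SEMAPHORE:
        return generate(i)


async def fixed_stream(i: int):
    # New shape: the semaphore is acquired inside the async generator
    # itself, so it stays held for the whole stream.
    async with SEMAPHORE:
        async for tok in generate(i):
            yield tok


async def consume(stream, out):
    async for tok in stream:
        out.append(tok)


async def main():
    interleaved, serialized = [], []

    # Both "requests" stream at once: their tokens interleave.
    await asyncio.gather(
        consume(await broken_stream(0), interleaved),
        consume(await broken_stream(1), interleaved),
    )

    # The second request cannot start streaming until the first finishes.
    await asyncio.gather(
        consume(fixed_stream(0), serialized),
        consume(fixed_stream(1), serialized),
    )

    print("broken:", interleaved)
    print("fixed: ", serialized)


asyncio.run(main())

Running this prints interleaved req0/req1 tokens for the broken variant, and all of req0 followed by all of req1 for the fixed one, which is the single-concurrent-request behavior the commit enforces.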