Move async with SEMAPHORE inside the async methods

Ashwin Bharambe 2024-10-08 16:53:05 -07:00
parent 4540d8bd87
commit 216e7eb4d5

@@ -85,7 +85,6 @@ class MetaReferenceInferenceImpl(Inference):
         if SEMAPHORE.locked():
             raise RuntimeError("Only one concurrent request is supported")

-        async with SEMAPHORE:
         if request.stream:
             return self._stream_chat_completion(request)
         else:
@@ -94,6 +93,7 @@ class MetaReferenceInferenceImpl(Inference):
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
+        async with SEMAPHORE:
             messages = chat_completion_request_to_messages(request)

             tokens = []
@@ -120,14 +120,18 @@ class MetaReferenceInferenceImpl(Inference):
                 logprobs.append(
                     TokenLogProbs(
-                        logprobs_by_token={token_result.text: token_result.logprobs[0]}
+                        logprobs_by_token={
+                            token_result.text: token_result.logprobs[0]
+                        }
                     )
                 )

             if stop_reason is None:
                 stop_reason = StopReason.out_of_tokens

-            message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )

             return ChatCompletionResponse(
                 completion_message=message,
                 logprobs=logprobs if request.logprobs else None,
@@ -136,6 +140,7 @@ class MetaReferenceInferenceImpl(Inference):
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
+        async with SEMAPHORE:
             messages = chat_completion_request_to_messages(request)

             yield ChatCompletionResponseStreamChunk(
@@ -213,7 +218,9 @@ class MetaReferenceInferenceImpl(Inference):
             if stop_reason is None:
                 stop_reason = StopReason.out_of_tokens

-            message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )

             parsed_tool_calls = len(message.tool_calls) > 0
             if ipython and not parsed_tool_calls:
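A note on why moving the `async with` matters, which the commit message does not spell out: `_stream_chat_completion` is an async generator, so its body does not run until the caller iterates it. Holding the semaphore only around the dispatch in `chat_completion` would therefore release it before any tokens are generated. The sketch below reproduces the pitfall and the fix; it is a minimal standalone example with hypothetical names (`stream_tokens`, `broken_dispatch`, `guarded_stream`), not the llama-stack code.

# Minimal sketch of the semaphore pitfall this commit fixes; the names
# below are hypothetical and not part of the llama-stack API.
import asyncio

SEMAPHORE = asyncio.Semaphore(1)

async def stream_tokens():
    # Async generator: this body runs only while the caller iterates it.
    for i in range(3):
        await asyncio.sleep(0.1)  # stand-in for model generation
        yield f"token-{i}"

async def broken_dispatch():
    # Looks guarded, but the returned generator has not started yet, so
    # the semaphore is released as soon as this coroutine returns.
    async with SEMAPHORE:
        return stream_tokens()

async def guarded_stream():
    # The fix: acquire the semaphore inside the generator, so it stays
    # held for the entire time the caller consumes tokens.
    async with SEMAPHORE:
        async for tok in stream_tokens():
            yield tok

async def main():
    gen = await broken_dispatch()
    assert not SEMAPHORE.locked()  # already released; generation runs unguarded
    async for tok in gen:
        print("broken:", tok)

    async for tok in guarded_stream():
        assert SEMAPHORE.locked()  # held while tokens are produced
        print("guarded:", tok)

asyncio.run(main())

The same reasoning applies to `_nonstream_chat_completion`: it is awaited rather than iterated, so wrapping its whole body keeps the one-request-at-a-time guard around the actual generation work in both paths.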