Move async with SEMAPHORE inside the async methods

Ashwin Bharambe 2024-10-08 16:53:05 -07:00
parent 4540d8bd87
commit 216e7eb4d5


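Why the move matters (a reconstruction from the commit message; the helper names in the sketch are illustrative, not from the patch): _stream_chat_completion is an async generator, so holding the semaphore only around the return in chat_completion released it as soon as the generator object was handed back, before a single token had been produced. A minimal sketch of that failure mode, assuming a one-permit asyncio.Semaphore:

import asyncio

SEMAPHORE = asyncio.Semaphore(1)

async def generate_tokens():
    # An async generator's body only runs while the caller iterates it.
    for tok in ("a", "b"):
        await asyncio.sleep(0)
        yield tok

async def old_chat_completion():
    # Old placement (sketch): the semaphore is released the moment the
    # generator object is returned; none of its body has run yet.
    async with SEMAPHORE:
        return generate_tokens()

async def main():
    agen = await old_chat_completion()
    # The semaphore is already free again, so a second request could
    # start generating concurrently with this one.
    assert not SEMAPHORE.locked()
    print([tok async for tok in agen])

asyncio.run(main())

Moving the async with into the worker methods, as the hunks below do, makes the lock span the actual generation.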
@@ -85,7 +85,6 @@ class MetaReferenceInferenceImpl(Inference):
         if SEMAPHORE.locked():
             raise RuntimeError("Only one concurrent request is supported")
-        async with SEMAPHORE:
         if request.stream:
             return self._stream_chat_completion(request)
         else:
@@ -94,6 +93,7 @@ class MetaReferenceInferenceImpl(Inference):
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
+        async with SEMAPHORE:
             messages = chat_completion_request_to_messages(request)
             tokens = []
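Note that on this non-streaming path _nonstream_chat_completion is a plain coroutine that the caller awaits, so either placement of the lock would have covered the full call; moving it inside keeps the path symmetric with the streaming case below, where only the inner placement actually works. The semaphore now brackets everything from prompt conversion through building the ChatCompletionResponse.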
@@ -120,14 +120,18 @@ class MetaReferenceInferenceImpl(Inference):
                     logprobs.append(
                         TokenLogProbs(
-                            logprobs_by_token={token_result.text: token_result.logprobs[0]}
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
                         )
                     )
             if stop_reason is None:
                 stop_reason = StopReason.out_of_tokens
-            message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
             return ChatCompletionResponse(
                 completion_message=message,
                 logprobs=logprobs if request.logprobs else None,
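The two rewrapped calls in this hunk carry no behavior change: pulling the body under async with SEMAPHORE: adds one indentation level, which pushes the one-line logprobs_by_token={...} and decode_assistant_message(...) calls past the line-length limit, so they are split across lines (presumably by the project's formatter).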
@@ -136,6 +140,7 @@ class MetaReferenceInferenceImpl(Inference):
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
+        async with SEMAPHORE:
             messages = chat_completion_request_to_messages(request)
             yield ChatCompletionResponseStreamChunk(
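Because _stream_chat_completion is an async generator, the async with SEMAPHORE: added here is entered only when the consumer starts iterating, and the lock is held until the stream is exhausted or closed. A self-contained sketch of that lifecycle (names are illustrative):

import asyncio

SEMAPHORE = asyncio.Semaphore(1)

async def stream():
    # Entered on the first __anext__, not when stream() is called.
    async with SEMAPHORE:
        yield "chunk"

async def main():
    agen = stream()               # nothing acquired yet
    assert not SEMAPHORE.locked()
    await agen.__anext__()        # first chunk: semaphore is now held
    assert SEMAPHORE.locked()
    await agen.aclose()           # closing the stream releases it
    assert not SEMAPHORE.locked()

asyncio.run(main())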
@@ -213,7 +218,9 @@ class MetaReferenceInferenceImpl(Inference):
             if stop_reason is None:
                 stop_reason = StopReason.out_of_tokens
-            message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
             parsed_tool_calls = len(message.tool_calls) > 0
             if ipython and not parsed_tool_calls:
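One design note (an observation, not something stated in the commit): the SEMAPHORE.locked() pre-check in chat_completion is a best-effort fail-fast so a second caller gets a clear error instead of silently queuing, while the async with inside each worker is the actual guard. There is a small window between the check and the acquisition, which is presumably acceptable for an implementation that only advertises a single concurrent request.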