Move async with SEMAPHORE inside the async methods

commit 216e7eb4d5
parent 4540d8bd87
Author: Ashwin Bharambe
Date:   2024-10-08 16:53:05 -07:00
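Why moving the guard matters (this rationale is inferred from asyncio semantics, not stated in the commit message): _stream_chat_completion is an async generator and _nonstream_chat_completion is a coroutine function, so calling either one only creates the generator/coroutine object; its body does not run until the caller iterates or awaits it. In the old layout, chat_completion wrapped the bare `return self._stream_chat_completion(request)` in `async with SEMAPHORE:`, so the semaphore was acquired and released around the mere creation of that object and was no longer held while tokens were actually being generated. With the guard inside the methods, the semaphore stays held for the full duration of the work. Below is a minimal, self-contained sketch of the difference, using toy names rather than the real provider code:

    import asyncio

    SEMAPHORE = asyncio.Semaphore(1)


    async def _tokens():
        # Stand-in for the token-producing loop; no guard of its own.
        for tok in ("hello", "world"):
            await asyncio.sleep(0.05)  # pretend this is slow generation work
            yield tok


    async def stream_guard_outside():
        # Old shape: the guard wraps only the creation of the generator object.
        async with SEMAPHORE:
            return _tokens()  # returns immediately; no token produced yet


    async def stream_guard_inside():
        # New shape: the guard lives inside the generator, so it is held
        # for as long as tokens are being produced.
        async with SEMAPHORE:
            async for tok in _tokens():
                yield tok


    async def main():
        gen = await stream_guard_outside()
        # The with-block has already exited, so nothing is held any more:
        print("held after old-style call? ", SEMAPHORE.locked())  # False

        agen = stream_guard_inside()
        first = await agen.__anext__()  # run the body up to the first yield
        print("held while new-style streams?", SEMAPHORE.locked())  # True
        await agen.aclose()  # unwinds the async-with, releasing the semaphore

        async for _ in gen:  # drain the unguarded generator to finish cleanly
            pass


    asyncio.run(main())

The same reasoning applies to the non-streaming path: returning the un-awaited `self._nonstream_chat_completion(request)` coroutine from inside the with-block would release the semaphore before the generation loop ever ran.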


@@ -85,112 +85,36 @@ class MetaReferenceInferenceImpl(Inference):
         if SEMAPHORE.locked():
             raise RuntimeError("Only one concurrent request is supported")
 
-        async with SEMAPHORE:
-            if request.stream:
-                return self._stream_chat_completion(request)
-            else:
-                return self._nonstream_chat_completion(request)
+        if request.stream:
+            return self._stream_chat_completion(request)
+        else:
+            return self._nonstream_chat_completion(request)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        messages = chat_completion_request_to_messages(request)
-
-        tokens = []
-        logprobs = []
-        stop_reason = None
-
-        for token_result in self.generator.chat_completion(
-            messages=messages,
-            temperature=request.sampling_params.temperature,
-            top_p=request.sampling_params.top_p,
-            max_gen_len=request.sampling_params.max_tokens,
-            logprobs=request.logprobs,
-            tool_prompt_format=request.tool_prompt_format,
-        ):
-            tokens.append(token_result.token)
-
-            if token_result.text == "<|eot_id|>":
-                stop_reason = StopReason.end_of_turn
-            elif token_result.text == "<|eom_id|>":
-                stop_reason = StopReason.end_of_message
-
-            if request.logprobs:
-                assert len(token_result.logprobs) == 1
-
-                logprobs.append(
-                    TokenLogProbs(
-                        logprobs_by_token={token_result.text: token_result.logprobs[0]}
-                    )
-                )
-
-        if stop_reason is None:
-            stop_reason = StopReason.out_of_tokens
-
-        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
-        return ChatCompletionResponse(
-            completion_message=message,
-            logprobs=logprobs if request.logprobs else None,
-        )
-
-    async def _stream_chat_completion(
-        self, request: ChatCompletionRequest
-    ) -> AsyncGenerator:
-        messages = chat_completion_request_to_messages(request)
-
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=ChatCompletionResponseEventType.start,
-                delta="",
-            )
-        )
-
-        tokens = []
-        logprobs = []
-        stop_reason = None
-        ipython = False
-
-        for token_result in self.generator.chat_completion(
-            messages=messages,
-            temperature=request.sampling_params.temperature,
-            top_p=request.sampling_params.top_p,
-            max_gen_len=request.sampling_params.max_tokens,
-            logprobs=request.logprobs,
-            tool_prompt_format=request.tool_prompt_format,
-        ):
-            tokens.append(token_result.token)
-
-            if not ipython and token_result.text.startswith("<|python_tag|>"):
-                ipython = True
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content="",
-                            parse_status=ToolCallParseStatus.started,
-                        ),
-                    )
-                )
-                continue
-
-            if token_result.text == "<|eot_id|>":
-                stop_reason = StopReason.end_of_turn
-                text = ""
-            elif token_result.text == "<|eom_id|>":
-                stop_reason = StopReason.end_of_message
-                text = ""
-            else:
-                text = token_result.text
-
-            if ipython:
-                delta = ToolCallDelta(
-                    content=text,
-                    parse_status=ToolCallParseStatus.in_progress,
-                )
-            else:
-                delta = text
-
-            if stop_reason is None:
+        async with SEMAPHORE:
+            messages = chat_completion_request_to_messages(request)
+
+            tokens = []
+            logprobs = []
+            stop_reason = None
+
+            for token_result in self.generator.chat_completion(
+                messages=messages,
+                temperature=request.sampling_params.temperature,
+                top_p=request.sampling_params.top_p,
+                max_gen_len=request.sampling_params.max_tokens,
+                logprobs=request.logprobs,
+                tool_prompt_format=request.tool_prompt_format,
+            ):
+                tokens.append(token_result.token)
+
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+
                 if request.logprobs:
                     assert len(token_result.logprobs) == 1
@@ -201,49 +125,132 @@ class MetaReferenceInferenceImpl(Inference):
                             }
                         )
                     )
 
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=delta,
-                        stop_reason=stop_reason,
-                        logprobs=logprobs if request.logprobs else None,
-                    )
-                )
-
-        if stop_reason is None:
-            stop_reason = StopReason.out_of_tokens
-
-        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
-
-        parsed_tool_calls = len(message.tool_calls) > 0
-        if ipython and not parsed_tool_calls:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.progress,
-                    delta=ToolCallDelta(
-                        content="",
-                        parse_status=ToolCallParseStatus.failure,
-                    ),
-                    stop_reason=stop_reason,
-                )
-            )
-
-        for tool_call in message.tool_calls:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.progress,
-                    delta=ToolCallDelta(
-                        content=tool_call,
-                        parse_status=ToolCallParseStatus.success,
-                    ),
-                    stop_reason=stop_reason,
-                )
-            )
-
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=ChatCompletionResponseEventType.complete,
-                delta="",
-                stop_reason=stop_reason,
-            )
-        )
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
+
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
+            return ChatCompletionResponse(
+                completion_message=message,
+                logprobs=logprobs if request.logprobs else None,
+            )
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        async with SEMAPHORE:
+            messages = chat_completion_request_to_messages(request)
+
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.start,
+                    delta="",
+                )
+            )
+
+            tokens = []
+            logprobs = []
+            stop_reason = None
+            ipython = False
+
+            for token_result in self.generator.chat_completion(
+                messages=messages,
+                temperature=request.sampling_params.temperature,
+                top_p=request.sampling_params.top_p,
+                max_gen_len=request.sampling_params.max_tokens,
+                logprobs=request.logprobs,
+                tool_prompt_format=request.tool_prompt_format,
+            ):
+                tokens.append(token_result.token)
+
+                if not ipython and token_result.text.startswith("<|python_tag|>"):
+                    ipython = True
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                content="",
+                                parse_status=ToolCallParseStatus.started,
+                            ),
+                        )
+                    )
+                    continue
+
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                    text = ""
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+                    text = ""
+                else:
+                    text = token_result.text
+
+                if ipython:
+                    delta = ToolCallDelta(
+                        content=text,
+                        parse_status=ToolCallParseStatus.in_progress,
+                    )
+                else:
+                    delta = text
+
+                if stop_reason is None:
+                    if request.logprobs:
+                        assert len(token_result.logprobs) == 1
+
+                        logprobs.append(
+                            TokenLogProbs(
+                                logprobs_by_token={
+                                    token_result.text: token_result.logprobs[0]
+                                }
+                            )
+                        )
+
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=delta,
+                            stop_reason=stop_reason,
+                            logprobs=logprobs if request.logprobs else None,
+                        )
+                    )
+
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
+
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
+
+            parsed_tool_calls = len(message.tool_calls) > 0
+            if ipython and not parsed_tool_calls:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            content="",
+                            parse_status=ToolCallParseStatus.failure,
+                        ),
+                        stop_reason=stop_reason,
+                    )
+                )
+
+            for tool_call in message.tool_calls:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            content=tool_call,
+                            parse_status=ToolCallParseStatus.success,
+                        ),
+                        stop_reason=stop_reason,
+                    )
+                )
+
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.complete,
+                    delta="",
+                    stop_reason=stop_reason,
+                )
+            )
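With the guard relocated, the `if SEMAPHORE.locked(): raise RuntimeError(...)` fast-fail that stays in the dispatching method still rejects a second caller once the first request has actually started producing tokens. A toy standalone sketch of that behaviour (simplified stand-ins, not the real MetaReferenceInferenceImpl):

    import asyncio

    SEMAPHORE = asyncio.Semaphore(1)


    async def _stream_chat_completion():
        # Guard inside the generator, as in the new version of the provider.
        async with SEMAPHORE:
            for tok in ("a", "b", "c"):
                await asyncio.sleep(0.05)  # simulated generation delay
                yield tok


    def chat_completion():
        # Fast-fail check kept in the dispatching method, as in the diff.
        if SEMAPHORE.locked():
            raise RuntimeError("Only one concurrent request is supported")
        return _stream_chat_completion()


    async def main():
        first = chat_completion()
        print(await first.__anext__())  # "a": streaming has started, semaphore held

        try:
            chat_completion()  # a second request while the first is mid-stream
        except RuntimeError as exc:
            print("second request rejected:", exc)

        async for tok in first:  # drain the first request; the guard releases at the end
            print(tok)


    asyncio.run(main())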