Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
Move async with SEMAPHORE inside the async methods

This commit is contained in:
parent 4540d8bd87
commit 216e7eb4d5

1 changed file with 138 additions and 131 deletions
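A note on why the semaphore moves inside the methods rather than staying around the dispatch (this is inferred from the code, not stated in the commit message): `_stream_chat_completion` is an async generator, and an async generator's body does not run until the caller starts iterating it, so a lock released at the end of the dispatch would already be gone before the first token is produced; likewise, the old code returned the `_nonstream_chat_completion` coroutine without awaiting it inside the `async with` block. The snippet below is a minimal, self-contained asyncio sketch of that generator behaviour — plain asyncio, not llama-stack code, with `stream_tokens` and `main` invented for illustration:

import asyncio

SEMAPHORE = asyncio.Semaphore(1)

async def stream_tokens():
    # The semaphore is acquired only once iteration starts, and it is held
    # until the generator body finishes.
    async with SEMAPHORE:
        for tok in ("hello", "world"):
            yield tok

async def main():
    gen = stream_tokens()          # creating the generator acquires nothing
    print(SEMAPHORE.locked())      # False
    print(await gen.__anext__())   # "hello" -- the semaphore is held from here on
    print(SEMAPHORE.locked())      # True
    async for tok in gen:          # drain the rest; released when the body ends
        print(tok)
    print(SEMAPHORE.locked())      # False

asyncio.run(main())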
@@ -85,112 +85,36 @@ class MetaReferenceInferenceImpl(Inference):
         if SEMAPHORE.locked():
             raise RuntimeError("Only one concurrent request is supported")

-        async with SEMAPHORE:
-            if request.stream:
-                return self._stream_chat_completion(request)
-            else:
-                return self._nonstream_chat_completion(request)
+        if request.stream:
+            return self._stream_chat_completion(request)
+        else:
+            return self._nonstream_chat_completion(request)

     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        messages = chat_completion_request_to_messages(request)
-
-        tokens = []
-        logprobs = []
-        stop_reason = None
-
-        for token_result in self.generator.chat_completion(
-            messages=messages,
-            temperature=request.sampling_params.temperature,
-            top_p=request.sampling_params.top_p,
-            max_gen_len=request.sampling_params.max_tokens,
-            logprobs=request.logprobs,
-            tool_prompt_format=request.tool_prompt_format,
-        ):
-            tokens.append(token_result.token)
-
-            if token_result.text == "<|eot_id|>":
-                stop_reason = StopReason.end_of_turn
-            elif token_result.text == "<|eom_id|>":
-                stop_reason = StopReason.end_of_message
-
-            if request.logprobs:
-                assert len(token_result.logprobs) == 1
-
-                logprobs.append(
-                    TokenLogProbs(
-                        logprobs_by_token={token_result.text: token_result.logprobs[0]}
-                    )
-                )
-
-        if stop_reason is None:
-            stop_reason = StopReason.out_of_tokens
-
-        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
-        return ChatCompletionResponse(
-            completion_message=message,
-            logprobs=logprobs if request.logprobs else None,
-        )
-
-    async def _stream_chat_completion(
-        self, request: ChatCompletionRequest
-    ) -> AsyncGenerator:
-        messages = chat_completion_request_to_messages(request)
-
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=ChatCompletionResponseEventType.start,
-                delta="",
-            )
-        )
-
-        tokens = []
-        logprobs = []
-        stop_reason = None
-        ipython = False
-
-        for token_result in self.generator.chat_completion(
-            messages=messages,
-            temperature=request.sampling_params.temperature,
-            top_p=request.sampling_params.top_p,
-            max_gen_len=request.sampling_params.max_tokens,
-            logprobs=request.logprobs,
-            tool_prompt_format=request.tool_prompt_format,
-        ):
-            tokens.append(token_result.token)
-
-            if not ipython and token_result.text.startswith("<|python_tag|>"):
-                ipython = True
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content="",
-                            parse_status=ToolCallParseStatus.started,
-                        ),
-                    )
-                )
-                continue
-
-            if token_result.text == "<|eot_id|>":
-                stop_reason = StopReason.end_of_turn
-                text = ""
-            elif token_result.text == "<|eom_id|>":
-                stop_reason = StopReason.end_of_message
-                text = ""
-            else:
-                text = token_result.text
-
-            if ipython:
-                delta = ToolCallDelta(
-                    content=text,
-                    parse_status=ToolCallParseStatus.in_progress,
-                )
-            else:
-                delta = text
-
-            if stop_reason is None:
+        async with SEMAPHORE:
+            messages = chat_completion_request_to_messages(request)
+
+            tokens = []
+            logprobs = []
+            stop_reason = None
+
+            for token_result in self.generator.chat_completion(
+                messages=messages,
+                temperature=request.sampling_params.temperature,
+                top_p=request.sampling_params.top_p,
+                max_gen_len=request.sampling_params.max_tokens,
+                logprobs=request.logprobs,
+                tool_prompt_format=request.tool_prompt_format,
+            ):
+                tokens.append(token_result.token)
+
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+
                 if request.logprobs:
                     assert len(token_result.logprobs) == 1
@@ -201,49 +125,132 @@ class MetaReferenceInferenceImpl(Inference):
                             }
                         )
                     )

-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=delta,
-                        stop_reason=stop_reason,
-                        logprobs=logprobs if request.logprobs else None,
-                    )
-                )
-
-        if stop_reason is None:
-            stop_reason = StopReason.out_of_tokens
-
-        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
-
-        parsed_tool_calls = len(message.tool_calls) > 0
-        if ipython and not parsed_tool_calls:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.progress,
-                    delta=ToolCallDelta(
-                        content="",
-                        parse_status=ToolCallParseStatus.failure,
-                    ),
-                    stop_reason=stop_reason,
-                )
-            )
-
-        for tool_call in message.tool_calls:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.progress,
-                    delta=ToolCallDelta(
-                        content=tool_call,
-                        parse_status=ToolCallParseStatus.success,
-                    ),
-                    stop_reason=stop_reason,
-                )
-            )
-
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=ChatCompletionResponseEventType.complete,
-                delta="",
-                stop_reason=stop_reason,
-            )
-        )
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
+
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
+            return ChatCompletionResponse(
+                completion_message=message,
+                logprobs=logprobs if request.logprobs else None,
+            )
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        async with SEMAPHORE:
+            messages = chat_completion_request_to_messages(request)
+
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.start,
+                    delta="",
+                )
+            )
+
+            tokens = []
+            logprobs = []
+            stop_reason = None
+            ipython = False
+
+            for token_result in self.generator.chat_completion(
+                messages=messages,
+                temperature=request.sampling_params.temperature,
+                top_p=request.sampling_params.top_p,
+                max_gen_len=request.sampling_params.max_tokens,
+                logprobs=request.logprobs,
+                tool_prompt_format=request.tool_prompt_format,
+            ):
+                tokens.append(token_result.token)
+
+                if not ipython and token_result.text.startswith("<|python_tag|>"):
+                    ipython = True
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                content="",
+                                parse_status=ToolCallParseStatus.started,
+                            ),
+                        )
+                    )
+                    continue
+
+                if token_result.text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                    text = ""
+                elif token_result.text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+                    text = ""
+                else:
+                    text = token_result.text
+
+                if ipython:
+                    delta = ToolCallDelta(
+                        content=text,
+                        parse_status=ToolCallParseStatus.in_progress,
+                    )
+                else:
+                    delta = text
+
+                if stop_reason is None:
+                    if request.logprobs:
+                        assert len(token_result.logprobs) == 1
+
+                        logprobs.append(
+                            TokenLogProbs(
+                                logprobs_by_token={
+                                    token_result.text: token_result.logprobs[0]
+                                }
+                            )
+                        )
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=delta,
+                            stop_reason=stop_reason,
+                            logprobs=logprobs if request.logprobs else None,
+                        )
+                    )
+
+            if stop_reason is None:
+                stop_reason = StopReason.out_of_tokens
+
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
+
+            parsed_tool_calls = len(message.tool_calls) > 0
+            if ipython and not parsed_tool_calls:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            content="",
+                            parse_status=ToolCallParseStatus.failure,
+                        ),
+                        stop_reason=stop_reason,
+                    )
+                )
+
+            for tool_call in message.tool_calls:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            content=tool_call,
+                            parse_status=ToolCallParseStatus.success,
+                        ),
+                        stop_reason=stop_reason,
+                    )
+                )
+
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.complete,
+                    delta="",
+                    stop_reason=stop_reason,
+                )
+            )