Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 07:14:20 +00:00)
Separate chat_completion stream and non-stream implementations
This is a pretty important requirement. The streaming response type is an AsyncGenerator while the non-stream one is a single object. So far this has worked _sometimes_ due to various pre-existing hacks (and in some cases, it simply failed).
parent f8752ab8dc · commit 0c9eb3341c
5 changed files with 346 additions and 287 deletions
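The pattern the whole commit converges on, in miniature: `chat_completion` becomes a plain `def` whose body only dispatches, so the caller receives either a coroutine (non-stream) or an async generator (stream) and chooses `await` or `async for` accordingly. A minimal, self-contained sketch of that shape (the class and response types here are illustrative stand-ins, not the llama-stack API):

import asyncio
from typing import AsyncGenerator


class ToyInference:
    # Plain `def`: the body only dispatches, so the caller receives either a
    # coroutine (stream=False) or an AsyncGenerator (stream=True).
    def chat_completion(self, prompt: str, stream: bool = False):
        if stream:
            return self._stream_chat_completion(prompt)
        else:
            return self._nonstream_chat_completion(prompt)

    async def _nonstream_chat_completion(self, prompt: str) -> str:
        # One awaitable call, one complete response.
        return f"echo: {prompt}"

    async def _stream_chat_completion(self, prompt: str) -> AsyncGenerator:
        # The `yield` makes this an async generator function.
        for token in ("echo", ":", " ", prompt):
            yield token


async def main() -> None:
    client = ToyInference()
    print(await client.chat_completion("hello"))             # single response
    async for chunk in client.chat_completion("hello", stream=True):
        print(chunk, end="")                                  # streamed chunks
    print()


if __name__ == "__main__":
    asyncio.run(main())

If the method were `async def`, even the streaming call would come back wrapped in a coroutine, which is exactly the mismatch the commit message describes.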
@@ -42,10 +42,10 @@ class InferenceClient(Inference):
     async def shutdown(self) -> None:
         pass

-    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    def completion(self, request: CompletionRequest) -> AsyncGenerator:
         raise NotImplementedError()

-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -66,48 +66,57 @@ class InferenceClient(Inference):
             stream=stream,
             logprobs=logprobs,
         )
-        async with httpx.AsyncClient() as client:
-            if stream:
-                async with client.stream(
-                    "POST",
-                    f"{self.base_url}/inference/chat_completion",
-                    json=encodable_dict(request),
-                    headers={"Content-Type": "application/json"},
-                    timeout=20,
-                ) as response:
-                    if response.status_code != 200:
-                        content = await response.aread()
-                        cprint(
-                            f"Error: HTTP {response.status_code} {content.decode()}",
-                            "red",
-                        )
-                        return
-
-                    async for line in response.aiter_lines():
-                        if line.startswith("data:"):
-                            data = line[len("data: ") :]
-                            try:
-                                if "error" in data:
-                                    cprint(data, "red")
-                                    continue
-
-                                yield ChatCompletionResponseStreamChunk(
-                                    **json.loads(data)
-                                )
-                            except Exception as e:
-                                print(data)
-                                print(f"Error with parsing or validation: {e}")
-            else:
-                response = await client.post(
-                    f"{self.base_url}/inference/chat_completion",
-                    json=encodable_dict(request),
-                    headers={"Content-Type": "application/json"},
-                    timeout=20,
-                )
-
-                response.raise_for_status()
-                j = response.json()
-                yield ChatCompletionResponse(**j)
+        if stream:
+            return self._stream_chat_completion(request)
+        else:
+            return self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                f"{self.base_url}/inference/chat_completion",
+                json=encodable_dict(request),
+                headers={"Content-Type": "application/json"},
+                timeout=20,
+            )
+
+            response.raise_for_status()
+            j = response.json()
+            yield ChatCompletionResponse(**j)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        async with httpx.AsyncClient() as client:
+            async with client.stream(
+                "POST",
+                f"{self.base_url}/inference/chat_completion",
+                json=encodable_dict(request),
+                headers={"Content-Type": "application/json"},
+                timeout=20,
+            ) as response:
+                if response.status_code != 200:
+                    content = await response.aread()
+                    cprint(
+                        f"Error: HTTP {response.status_code} {content.decode()}",
+                        "red",
+                    )
+                    return
+
+                async for line in response.aiter_lines():
+                    if line.startswith("data:"):
+                        data = line[len("data: ") :]
+                        try:
+                            if "error" in data:
+                                cprint(data, "red")
+                                continue
+
+                            yield ChatCompletionResponseStreamChunk(**json.loads(data))
+                        except Exception as e:
+                            print(data)
+                            print(f"Error with parsing or validation: {e}")


     async def run_main(
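For reference, the streaming helper above reads server-sent events off an httpx streaming response and strips the `data: ` prefix before decoding each chunk. A stripped-down sketch of that loop, assuming a compatible server is already listening at `base_url` and leaving the chunks as plain dicts instead of the typed models:

import json
from typing import AsyncGenerator

import httpx


async def stream_chat_completion(base_url: str, payload: dict) -> AsyncGenerator:
    # Open a streaming POST and parse server-sent events line by line.
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST",
            f"{base_url}/inference/chat_completion",
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=20,
        ) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.startswith("data:"):
                    continue
                data = line[len("data: ") :]
                try:
                    yield json.loads(data)  # one decoded chunk per SSE event
                except json.JSONDecodeError as e:
                    print(f"Could not parse chunk: {e}")

The non-stream path is just a regular `client.post` followed by `response.json()`, which is why the two cases no longer share one method.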
@@ -180,8 +180,10 @@ class ModelStore(Protocol):
 class Inference(Protocol):
     model_store: ModelStore

+    # This method is not `async def` because it can result in either an
+    # `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
     @webmethod(route="/inference/completion")
-    async def completion(
+    def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -190,8 +192,10 @@ class Inference(Protocol):
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...

+    # This method is not `async def` because it can result in either an
+    # `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
     @webmethod(route="/inference/chat_completion")
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
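The protocol keeps its original Union return annotations and leans on the new comments to explain the dual behaviour. If one wanted the signature itself to spell out the two possible call results, a stricter variant (hypothetical, not what the diff does) could look like this:

from typing import AsyncIterator, Awaitable, Protocol, Union


class ChatCompletionResponse: ...


class ChatCompletionResponseStreamChunk: ...


class InferenceLike(Protocol):
    # A plain `def` in the Protocol: calling an `async def` implementation
    # produces an awaitable response, while calling an async-generator
    # implementation produces an async iterator of chunks; the union covers both.
    def chat_completion(
        self, model: str, stream: bool = False
    ) -> Union[
        Awaitable[ChatCompletionResponse],
        AsyncIterator[ChatCompletionResponseStreamChunk],
    ]: ...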
@@ -70,7 +70,7 @@ class InferenceRouter(Inference):
     async def register_model(self, model: ModelDef) -> None:
         await self.routing_table.register_model(model)

-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -91,27 +91,32 @@ class InferenceRouter(Inference):
             stream=stream,
             logprobs=logprobs,
         )
-        # TODO: we need to fix streaming response to align provider implementations with Protocol.
-        async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
-            **params
-        ):
-            yield chunk
+        provider = self.routing_table.get_provider_impl(model)
+        if stream:
+            return (chunk async for chunk in provider.chat_completion(**params))
+        else:
+            return provider.chat_completion(**params)

-    async def completion(
+    def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
-        return await self.routing_table.get_provider_impl(model).completion(
+    ) -> AsyncGenerator:
+        provider = self.routing_table.get_provider_impl(model)
+        params = dict(
             model=model,
             content=content,
             sampling_params=sampling_params,
             stream=stream,
             logprobs=logprobs,
         )
+        if stream:
+            return (chunk async for chunk in provider.completion(**params))
+        else:
+            return provider.completion(**params)

     async def embeddings(
         self,
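Two details in the router are worth spelling out: the streaming case wraps the provider call in an async generator expression, which builds the generator lazily without the router itself becoming an async generator, and the non-stream case hands the provider's coroutine back untouched for the caller to await. A small self-contained sketch of that dispatch, with stand-in provider functions rather than the routing table API:

import asyncio
from typing import AsyncGenerator


async def provider_stream() -> AsyncGenerator:
    # Stand-in for provider.chat_completion(..., stream=True).
    for i in range(3):
        yield f"chunk-{i}"


async def provider_single() -> str:
    # Stand-in for provider.chat_completion(..., stream=False).
    return "complete response"


def route(stream: bool):
    # Mirrors the router's dispatch: wrap the provider stream in a generator
    # expression, or pass the provider coroutine through untouched.
    if stream:
        return (chunk async for chunk in provider_stream())
    else:
        return provider_single()


async def main() -> None:
    print(await route(stream=False))
    async for chunk in route(stream=True):
        print(chunk)


if __name__ == "__main__":
    asyncio.run(main())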
@@ -55,7 +55,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass

-    async def completion(
+    def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -79,7 +79,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):

         return options

-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -90,24 +90,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        request = ChatCompletionRequest(
-            model=model,
-            messages=messages,
-            sampling_params=sampling_params,
-            tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
-            stream=stream,
-            logprobs=logprobs,
-        )
-
-        messages = augment_messages_for_tools(request)
-        model_input = self.formatter.encode_dialog_prompt(messages)
-        prompt = self.tokenizer.decode(model_input.tokens)
-
-        # accumulate sampling params and other options to pass to ollama
-        options = self.get_ollama_chat_options(request)
-        ollama_model = self.map_to_provider_model(request.model)
+        ollama_model = self.map_to_provider_model(model)

         res = await self.client.ps()
         need_model_pull = True
@@ -123,133 +106,166 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
                 status["status"] == "success"
             ), f"Failed to pull model {self.model} in ollama"

-        common_params = {
-            "model": ollama_model,
-            "prompt": prompt,
-            "options": options,
-            "raw": True,
-            "stream": request.stream,
-        }
-
-        if not request.stream:
-            r = await self.client.generate(**common_params)
-            stop_reason = None
-            if r["done"]:
-                if r["done_reason"] == "stop":
-                    stop_reason = StopReason.end_of_turn
-                elif r["done_reason"] == "length":
-                    stop_reason = StopReason.out_of_tokens
-
-            completion_message = self.formatter.decode_assistant_message_from_content(
-                r["response"], stop_reason
-            )
-            yield ChatCompletionResponse(
-                completion_message=completion_message,
-                logprobs=None,
-            )
-        else:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.start,
-                    delta="",
-                )
-            )
-            stream = await self.client.generate(**common_params)
-
-            buffer = ""
-            ipython = False
-            stop_reason = None
-
-            async for chunk in stream:
-                if chunk["done"]:
-                    if stop_reason is None and chunk["done_reason"] == "stop":
-                        stop_reason = StopReason.end_of_turn
-                    elif stop_reason is None and chunk["done_reason"] == "length":
-                        stop_reason = StopReason.out_of_tokens
-                    break
-
-                text = chunk["response"]
-                # check if its a tool call ( aka starts with <|python_tag|> )
-                if not ipython and text.startswith("<|python_tag|>"):
-                    ipython = True
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content="",
-                                parse_status=ToolCallParseStatus.started,
-                            ),
-                        )
-                    )
-                    buffer += text
-                    continue
-
-                if ipython:
-                    if text == "<|eot_id|>":
-                        stop_reason = StopReason.end_of_turn
-                        text = ""
-                        continue
-                    elif text == "<|eom_id|>":
-                        stop_reason = StopReason.end_of_message
-                        text = ""
-                        continue
-
-                    buffer += text
-                    delta = ToolCallDelta(
-                        content=text,
-                        parse_status=ToolCallParseStatus.in_progress,
-                    )
-
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=delta,
-                            stop_reason=stop_reason,
-                        )
-                    )
-                else:
-                    buffer += text
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=text,
-                            stop_reason=stop_reason,
-                        )
-                    )
-
-            # parse tool calls and report errors
-            message = self.formatter.decode_assistant_message_from_content(
-                buffer, stop_reason
-            )
-            parsed_tool_calls = len(message.tool_calls) > 0
-            if ipython and not parsed_tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content="",
-                            parse_status=ToolCallParseStatus.failure,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
-
-            for tool_call in message.tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content=tool_call,
-                            parse_status=ToolCallParseStatus.success,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
-
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.complete,
-                    delta="",
-                    stop_reason=stop_reason,
-                )
-            )
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
+        if stream:
+            return self._stream_chat_completion(request)
+        else:
+            return self._nonstream_chat_completion(request)
+
+    def _get_params(self, request: ChatCompletionRequest) -> dict:
+        messages = augment_messages_for_tools(request)
+        model_input = self.formatter.encode_dialog_prompt(messages)
+        prompt = self.tokenizer.decode(model_input.tokens)
+
+        # accumulate sampling params and other options to pass to ollama
+        options = self.get_ollama_chat_options(request)
+
+        return {
+            "model": self.map_to_provider_model(request.model),
+            "prompt": prompt,
+            "options": options,
+            "raw": True,
+            "stream": request.stream,
+        }
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        params = self._get_params(request)
+        r = await self.client.generate(**params)
+        stop_reason = None
+        if r["done"]:
+            if r["done_reason"] == "stop":
+                stop_reason = StopReason.end_of_turn
+            elif r["done_reason"] == "length":
+                stop_reason = StopReason.out_of_tokens
+
+        completion_message = self.formatter.decode_assistant_message_from_content(
+            r["response"], stop_reason
+        )
+        return ChatCompletionResponse(
+            completion_message=completion_message,
+            logprobs=None,
+        )
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        params = self._get_params(request)
+
+        stream = await self.client.generate(**params)
+
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.start,
+                delta="",
+            )
+        )
+
+        buffer = ""
+        ipython = False
+        stop_reason = None
+
+        async for chunk in stream:
+            if chunk["done"]:
+                if stop_reason is None and chunk["done_reason"] == "stop":
+                    stop_reason = StopReason.end_of_turn
+                elif stop_reason is None and chunk["done_reason"] == "length":
+                    stop_reason = StopReason.out_of_tokens
+                break
+
+            text = chunk["response"]
+            # check if its a tool call ( aka starts with <|python_tag|> )
+            if not ipython and text.startswith("<|python_tag|>"):
+                ipython = True
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            content="",
+                            parse_status=ToolCallParseStatus.started,
+                        ),
+                    )
+                )
+                buffer += text
+                continue
+
+            if ipython:
+                if text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                    text = ""
+                    continue
+                elif text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+                    text = ""
+                    continue
+
+                buffer += text
+                delta = ToolCallDelta(
+                    content=text,
+                    parse_status=ToolCallParseStatus.in_progress,
+                )
+
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=delta,
+                        stop_reason=stop_reason,
+                    )
+                )
+            else:
+                buffer += text
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=text,
+                        stop_reason=stop_reason,
+                    )
+                )
+
+        # parse tool calls and report errors
+        message = self.formatter.decode_assistant_message_from_content(
+            buffer, stop_reason
+        )
+        parsed_tool_calls = len(message.tool_calls) > 0
+        if ipython and not parsed_tool_calls:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        content="",
+                        parse_status=ToolCallParseStatus.failure,
+                    ),
+                    stop_reason=stop_reason,
+                )
+            )
+
+        for tool_call in message.tool_calls:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        content=tool_call,
+                        parse_status=ToolCallParseStatus.success,
+                    ),
+                    stop_reason=stop_reason,
+                )
+            )
+
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.complete,
+                delta="",
+                stop_reason=stop_reason,
+            )
+        )
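The adapter change is mostly a de-duplication: both paths now build the identical ollama payload through `_get_params`, and only the `stream` flag in that payload (plus how the result is consumed) differs. A toy sketch of the same shape, using a fake `generate`-style client rather than the real ollama one:

from typing import Any, AsyncGenerator, Dict


class FakeGenerateClient:
    # Stand-in for the ollama client: depending on the `stream` flag in the
    # payload, `generate` resolves to either a single dict or an async
    # iterator of dicts.
    async def generate(self, **kwargs: Any):
        if kwargs.get("stream"):
            async def _chunks() -> AsyncGenerator:
                for piece in ("Hel", "lo"):
                    yield {"response": piece, "done": False}
                yield {"response": "", "done": True, "done_reason": "stop"}

            return _chunks()
        return {"response": "Hello", "done": True, "done_reason": "stop"}


class ToyAdapter:
    def __init__(self) -> None:
        self.client = FakeGenerateClient()

    def _get_params(self, prompt: str, stream: bool) -> Dict[str, Any]:
        # One place builds the payload; only the stream flag differs.
        return {"model": "toy", "prompt": prompt, "raw": True, "stream": stream}

    def chat_completion(self, prompt: str, stream: bool = False):
        if stream:
            return self._stream_chat_completion(prompt)
        else:
            return self._nonstream_chat_completion(prompt)

    async def _nonstream_chat_completion(self, prompt: str) -> str:
        r = await self.client.generate(**self._get_params(prompt, stream=False))
        return r["response"]

    async def _stream_chat_completion(self, prompt: str) -> AsyncGenerator:
        chunks = await self.client.generate(**self._get_params(prompt, stream=True))
        async for chunk in chunks:
            if chunk["done"]:
                break
            yield chunk["response"]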
@@ -46,9 +46,7 @@ class MetaReferenceInferenceImpl(Inference):
     async def shutdown(self) -> None:
         self.generator.stop()

-    # hm, when stream=False, we should not be doing SSE :/ which is what the
-    # top-level server is going to do. make the typing more specific here
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -59,6 +57,9 @@ class MetaReferenceInferenceImpl(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if logprobs:
+            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
+
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
@@ -71,7 +72,6 @@ class MetaReferenceInferenceImpl(Inference):
             logprobs=logprobs,
         )

-        messages = augment_messages_for_tools(request)
         model = resolve_model(request.model)
         if model is None:
             raise RuntimeError(
@@ -87,138 +87,163 @@ class MetaReferenceInferenceImpl(Inference):

         async with SEMAPHORE:
             if request.stream:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.start,
-                        delta="",
-                    )
-                )
-
-            tokens = []
-            logprobs = []
-
-            stop_reason = None
-
-            buffer = ""
-            ipython = False
-
-            for token_result in self.generator.chat_completion(
-                messages=messages,
-                temperature=request.sampling_params.temperature,
-                top_p=request.sampling_params.top_p,
-                max_gen_len=request.sampling_params.max_tokens,
-                logprobs=request.logprobs,
-                tool_prompt_format=request.tool_prompt_format,
-            ):
-                buffer += token_result.text
-                tokens.append(token_result.token)
-
-                if not ipython and buffer.startswith("<|python_tag|>"):
-                    ipython = True
-                    if request.stream:
-                        yield ChatCompletionResponseStreamChunk(
-                            event=ChatCompletionResponseEvent(
-                                event_type=ChatCompletionResponseEventType.progress,
-                                delta=ToolCallDelta(
-                                    content="",
-                                    parse_status=ToolCallParseStatus.started,
-                                ),
-                            )
-                        )
-
-                    buffer = buffer[len("<|python_tag|>") :]
-                    continue
-
-                if not request.stream:
-                    if request.logprobs:
-                        assert (
-                            len(token_result.logprobs) == 1
-                        ), "Expected logprob to contain 1 result for the current token"
-                        assert (
-                            request.logprobs.top_k == 1
-                        ), "Only top_k=1 is supported for LogProbConfig"
-
-                        logprobs.append(
-                            TokenLogProbs(
-                                logprobs_by_token={
-                                    token_result.text: token_result.logprobs[0]
-                                }
-                            )
-                        )
-
-                    continue
-
-                if token_result.text == "<|eot_id|>":
-                    stop_reason = StopReason.end_of_turn
-                    text = ""
-                elif token_result.text == "<|eom_id|>":
-                    stop_reason = StopReason.end_of_message
-                    text = ""
-                else:
-                    text = token_result.text
-
-                if ipython:
-                    delta = ToolCallDelta(
-                        content=text,
-                        parse_status=ToolCallParseStatus.in_progress,
-                    )
-                else:
-                    delta = text
-
-                if stop_reason is None:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=delta,
-                            stop_reason=stop_reason,
-                        )
-                    )
-
-            if stop_reason is None:
-                stop_reason = StopReason.out_of_tokens
-
-            # TODO(ashwin): parse tool calls separately here and report errors?
-            # if someone breaks the iteration before coming here we are toast
-            message = self.generator.formatter.decode_assistant_message(
-                tokens, stop_reason
-            )
-            if request.stream:
-                parsed_tool_calls = len(message.tool_calls) > 0
-                if ipython and not parsed_tool_calls:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content="",
-                                parse_status=ToolCallParseStatus.failure,
-                            ),
-                            stop_reason=stop_reason,
-                        )
-                    )
-
-                for tool_call in message.tool_calls:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content=tool_call,
-                                parse_status=ToolCallParseStatus.success,
-                            ),
-                            stop_reason=stop_reason,
-                        )
-                    )
-
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.complete,
-                        delta="",
-                        stop_reason=stop_reason,
-                    )
-                )
-
-            # TODO(ashwin): what else do we need to send out here when everything finishes?
-            else:
-                yield ChatCompletionResponse(
-                    completion_message=message,
-                    logprobs=logprobs if request.logprobs else None,
-                )
+                return self._stream_chat_completion(request)
+            else:
+                return self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        messages = augment_messages_for_tools(request)
+
+        tokens = []
+        logprobs = []
+        stop_reason = None
+
+        for token_result in self.generator.chat_completion(
+            messages=messages,
+            temperature=request.sampling_params.temperature,
+            top_p=request.sampling_params.top_p,
+            max_gen_len=request.sampling_params.max_tokens,
+            logprobs=request.logprobs,
+            tool_prompt_format=request.tool_prompt_format,
+        ):
+            tokens.append(token_result.token)
+
+            if token_result.text == "<|eot_id|>":
+                stop_reason = StopReason.end_of_turn
+            elif token_result.text == "<|eom_id|>":
+                stop_reason = StopReason.end_of_message
+
+            if request.logprobs:
+                assert len(token_result.logprobs) == 1
+
+                logprobs.append(
+                    TokenLogProbs(
+                        logprobs_by_token={token_result.text: token_result.logprobs[0]}
+                    )
+                )
+
+        if stop_reason is None:
+            stop_reason = StopReason.out_of_tokens
+
+        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
+        return ChatCompletionResponse(
+            completion_message=message,
+            logprobs=logprobs if request.logprobs else None,
+        )
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        messages = augment_messages_for_tools(request)
+
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.start,
+                delta="",
+            )
+        )
+
+        tokens = []
+        logprobs = []
+        stop_reason = None
+        ipython = False
+
+        for token_result in self.generator.chat_completion(
+            messages=messages,
+            temperature=request.sampling_params.temperature,
+            top_p=request.sampling_params.top_p,
+            max_gen_len=request.sampling_params.max_tokens,
+            logprobs=request.logprobs,
+            tool_prompt_format=request.tool_prompt_format,
+        ):
+            tokens.append(token_result.token)
+
+            if not ipython and token_result.text.startswith("<|python_tag|>"):
+                ipython = True
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            content="",
+                            parse_status=ToolCallParseStatus.started,
+                        ),
+                    )
+                )
+                continue
+
+            if token_result.text == "<|eot_id|>":
+                stop_reason = StopReason.end_of_turn
+                text = ""
+            elif token_result.text == "<|eom_id|>":
+                stop_reason = StopReason.end_of_message
+                text = ""
+            else:
+                text = token_result.text
+
+            if ipython:
+                delta = ToolCallDelta(
+                    content=text,
+                    parse_status=ToolCallParseStatus.in_progress,
+                )
+            else:
+                delta = text
+
+            if stop_reason is None:
+                if request.logprobs:
+                    assert len(token_result.logprobs) == 1
+
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )
+
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=delta,
+                        stop_reason=stop_reason,
+                        logprobs=logprobs if request.logprobs else None,
+                    )
+                )
+
+        if stop_reason is None:
+            stop_reason = StopReason.out_of_tokens
+
+        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
+        parsed_tool_calls = len(message.tool_calls) > 0
+        if ipython and not parsed_tool_calls:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        content="",
+                        parse_status=ToolCallParseStatus.failure,
+                    ),
+                    stop_reason=stop_reason,
+                )
+            )
+
+        for tool_call in message.tool_calls:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        content=tool_call,
+                        parse_status=ToolCallParseStatus.success,
+                    ),
+                    stop_reason=stop_reason,
+                )
+            )
+
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.complete,
+                delta="",
+                stop_reason=stop_reason,
+            )
+        )
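Both streaming implementations emit the same envelope: one `start` chunk, a `progress` chunk per decoded token or tool-call delta, and a closing `complete` chunk carrying the stop reason. A minimal sketch of that envelope, with plain dicts standing in for the stream-chunk and event models:

from typing import AsyncGenerator, Iterable


async def stream_events(tokens: Iterable[str]) -> AsyncGenerator:
    # Mirrors the start -> progress... -> complete envelope used by the
    # streaming helpers above; the dicts are stand-ins for
    # ChatCompletionResponseStreamChunk / ChatCompletionResponseEvent.
    yield {"event_type": "start", "delta": ""}
    stop_reason = None
    for token in tokens:
        if token == "<|eot_id|>":
            stop_reason = "end_of_turn"
            break
        yield {"event_type": "progress", "delta": token}
    if stop_reason is None:
        stop_reason = "out_of_tokens"
    yield {"event_type": "complete", "delta": "", "stop_reason": stop_reason}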