Separate chat_completion stream and non-stream implementations

This is a pretty important requirement. The streaming response type is
an AsyncGenerator while the non-stream one is a single object. So far
this has worked _sometimes_ due to various pre-existing hacks (and in
some cases, just failed).
Authored and committed by Ashwin Bharambe on 2024-10-08 10:52:16 -07:00
parent f8752ab8dc
commit 0c9eb3341c
5 changed files with 346 additions and 287 deletions
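The whole change boils down to one dispatch pattern: a plain (non-async) def entry point that hands back a coroutine when stream=False and an async generator when stream=True, so the two cases never share a code path. The sketch below is illustrative only -- SketchProvider, _nonstream and _stream are hypothetical names, not code from this commit:

from typing import AsyncGenerator


class SketchProvider:
    # Not `async def`: the caller decides whether to await or iterate the result.
    def chat_completion(self, prompt: str, stream: bool = False):
        if stream:
            return self._stream(prompt)  # async generator object
        else:
            return self._nonstream(prompt)  # coroutine object

    async def _nonstream(self, prompt: str) -> str:
        return f"full response to {prompt!r}"

    async def _stream(self, prompt: str) -> AsyncGenerator[str, None]:
        for word in f"streamed response to {prompt!r}".split():
            yield word

A caller awaits the non-stream case (response = await provider.chat_completion(p)) and iterates the stream case (async for chunk in provider.chat_completion(p, stream=True)); the hunks below apply exactly this split to each provider implementation.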


@@ -42,10 +42,10 @@ class InferenceClient(Inference):
     async def shutdown(self) -> None:
         pass

-    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    def completion(self, request: CompletionRequest) -> AsyncGenerator:
         raise NotImplementedError()

-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -66,48 +66,57 @@ class InferenceClient(Inference):
             stream=stream,
             logprobs=logprobs,
         )
+        if stream:
+            return self._stream_chat_completion(request)
+        else:
+            return self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
         async with httpx.AsyncClient() as client:
-            if stream:
-                async with client.stream(
-                    "POST",
-                    f"{self.base_url}/inference/chat_completion",
-                    json=encodable_dict(request),
-                    headers={"Content-Type": "application/json"},
-                    timeout=20,
-                ) as response:
-                    if response.status_code != 200:
-                        content = await response.aread()
-                        cprint(
-                            f"Error: HTTP {response.status_code} {content.decode()}",
-                            "red",
-                        )
-                        return
-
-                    async for line in response.aiter_lines():
-                        if line.startswith("data:"):
-                            data = line[len("data: ") :]
-                            try:
-                                if "error" in data:
-                                    cprint(data, "red")
-                                    continue
-
-                                yield ChatCompletionResponseStreamChunk(
-                                    **json.loads(data)
-                                )
-                            except Exception as e:
-                                print(data)
-                                print(f"Error with parsing or validation: {e}")
-            else:
-                response = await client.post(
-                    f"{self.base_url}/inference/chat_completion",
-                    json=encodable_dict(request),
-                    headers={"Content-Type": "application/json"},
-                    timeout=20,
-                )
-
-                response.raise_for_status()
-                j = response.json()
-                yield ChatCompletionResponse(**j)
+            response = await client.post(
+                f"{self.base_url}/inference/chat_completion",
+                json=encodable_dict(request),
+                headers={"Content-Type": "application/json"},
+                timeout=20,
+            )
+
+            response.raise_for_status()
+            j = response.json()
+            yield ChatCompletionResponse(**j)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        async with httpx.AsyncClient() as client:
+            async with client.stream(
+                "POST",
+                f"{self.base_url}/inference/chat_completion",
+                json=encodable_dict(request),
+                headers={"Content-Type": "application/json"},
+                timeout=20,
+            ) as response:
+                if response.status_code != 200:
+                    content = await response.aread()
+                    cprint(
+                        f"Error: HTTP {response.status_code} {content.decode()}",
+                        "red",
+                    )
+                    return
+
+                async for line in response.aiter_lines():
+                    if line.startswith("data:"):
+                        data = line[len("data: ") :]
+                        try:
+                            if "error" in data:
+                                cprint(data, "red")
+                                continue
+
+                            yield ChatCompletionResponseStreamChunk(**json.loads(data))
+                        except Exception as e:
+                            print(data)
+                            print(f"Error with parsing or validation: {e}")


 async def run_main(


@@ -180,8 +180,10 @@ class ModelStore(Protocol):
 class Inference(Protocol):
     model_store: ModelStore

+    # This method is not `async def` because it can result in either an
+    # `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
     @webmethod(route="/inference/completion")
-    async def completion(
+    def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -190,8 +192,10 @@ class Inference(Protocol):
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...

+    # This method is not `async def` because it can result in either an
+    # `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
     @webmethod(route="/inference/chat_completion")
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
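Assuming the intent stated in the comments above, a caller of this protocol consumes the two shapes like so (a hedged sketch: impl, model and messages are placeholders, and the attribute names mirror the response types used elsewhere in this diff):

async def call_both_ways(impl, model, messages):
    # stream=False: chat_completion returns a coroutine producing a ChatCompletionResponse
    response = await impl.chat_completion(model=model, messages=messages, stream=False)
    print(response.completion_message)

    # stream=True: chat_completion returns an AsyncGenerator of stream chunks
    async for chunk in impl.chat_completion(model=model, messages=messages, stream=True):
        print(chunk.event.delta, end="")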


@@ -70,7 +70,7 @@ class InferenceRouter(Inference):
     async def register_model(self, model: ModelDef) -> None:
         await self.routing_table.register_model(model)

-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -91,27 +91,32 @@ class InferenceRouter(Inference):
             stream=stream,
             logprobs=logprobs,
         )
-        # TODO: we need to fix streaming response to align provider implementations with Protocol.
-        async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
-            **params
-        ):
-            yield chunk
+        provider = self.routing_table.get_provider_impl(model)
+        if stream:
+            return (chunk async for chunk in provider.chat_completion(**params))
+        else:
+            return provider.chat_completion(**params)

-    async def completion(
+    def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
-        return await self.routing_table.get_provider_impl(model).completion(
+    ) -> AsyncGenerator:
+        provider = self.routing_table.get_provider_impl(model)
+        params = dict(
             model=model,
             content=content,
             sampling_params=sampling_params,
             stream=stream,
             logprobs=logprobs,
         )
+        if stream:
+            return (chunk async for chunk in provider.completion(**params))
+        else:
+            return provider.completion(**params)

     async def embeddings(
         self,
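The router's `(chunk async for chunk in provider.chat_completion(**params))` leans on async generator expressions: wrapping an async iterable this way yields a new async generator that re-emits every item. A standalone illustration of just that language feature (plain standard library, not llama-stack code):

import asyncio


async def source():
    for i in range(3):
        yield i


async def main():
    wrapped = (item async for item in source())  # async generator expression
    async for item in wrapped:
        print(item)  # prints 0, 1, 2


asyncio.run(main())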


@@ -55,7 +55,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass

-    async def completion(
+    def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -79,7 +79,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
         return options

-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -90,24 +90,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        request = ChatCompletionRequest(
-            model=model,
-            messages=messages,
-            sampling_params=sampling_params,
-            tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
-            stream=stream,
-            logprobs=logprobs,
-        )
-
-        messages = augment_messages_for_tools(request)
-        model_input = self.formatter.encode_dialog_prompt(messages)
-        prompt = self.tokenizer.decode(model_input.tokens)
-
-        # accumulate sampling params and other options to pass to ollama
-        options = self.get_ollama_chat_options(request)
-        ollama_model = self.map_to_provider_model(request.model)
+        ollama_model = self.map_to_provider_model(model)

         res = await self.client.ps()
         need_model_pull = True
@@ -123,133 +106,166 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
                     status["status"] == "success"
                 ), f"Failed to pull model {self.model} in ollama"

-        common_params = {
-            "model": ollama_model,
-            "prompt": prompt,
-            "options": options,
-            "raw": True,
-            "stream": request.stream,
-        }
-
-        if not request.stream:
-            r = await self.client.generate(**common_params)
-            stop_reason = None
-            if r["done"]:
-                if r["done_reason"] == "stop":
-                    stop_reason = StopReason.end_of_turn
-                elif r["done_reason"] == "length":
-                    stop_reason = StopReason.out_of_tokens
-
-            completion_message = self.formatter.decode_assistant_message_from_content(
-                r["response"], stop_reason
-            )
-            yield ChatCompletionResponse(
-                completion_message=completion_message,
-                logprobs=None,
-            )
-        else:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.start,
-                    delta="",
-                )
-            )
-            stream = await self.client.generate(**common_params)
-
-            buffer = ""
-            ipython = False
-            stop_reason = None
-
-            async for chunk in stream:
-                if chunk["done"]:
-                    if stop_reason is None and chunk["done_reason"] == "stop":
-                        stop_reason = StopReason.end_of_turn
-                    elif stop_reason is None and chunk["done_reason"] == "length":
-                        stop_reason = StopReason.out_of_tokens
-                    break
-
-                text = chunk["response"]
-
-                # check if its a tool call ( aka starts with <|python_tag|> )
-                if not ipython and text.startswith("<|python_tag|>"):
-                    ipython = True
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content="",
-                                parse_status=ToolCallParseStatus.started,
-                            ),
-                        )
-                    )
-                    buffer += text
-                    continue
-
-                if ipython:
-                    if text == "<|eot_id|>":
-                        stop_reason = StopReason.end_of_turn
-                        text = ""
-                        continue
-                    elif text == "<|eom_id|>":
-                        stop_reason = StopReason.end_of_message
-                        text = ""
-                        continue
-
-                    buffer += text
-                    delta = ToolCallDelta(
-                        content=text,
-                        parse_status=ToolCallParseStatus.in_progress,
-                    )
-
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=delta,
-                            stop_reason=stop_reason,
-                        )
-                    )
-                else:
-                    buffer += text
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=text,
-                            stop_reason=stop_reason,
-                        )
-                    )
-
-            # parse tool calls and report errors
-            message = self.formatter.decode_assistant_message_from_content(
-                buffer, stop_reason
-            )
-            parsed_tool_calls = len(message.tool_calls) > 0
-            if ipython and not parsed_tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content="",
-                            parse_status=ToolCallParseStatus.failure,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
-
-            for tool_call in message.tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content=tool_call,
-                            parse_status=ToolCallParseStatus.success,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
-
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.complete,
-                    delta="",
-                    stop_reason=stop_reason,
-                )
-            )
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        if stream:
+            return self._stream_chat_completion(request)
+        else:
+            return self._nonstream_chat_completion(request)
+
+    def _get_params(self, request: ChatCompletionRequest) -> dict:
+        messages = augment_messages_for_tools(request)
+        model_input = self.formatter.encode_dialog_prompt(messages)
+        prompt = self.tokenizer.decode(model_input.tokens)
+
+        # accumulate sampling params and other options to pass to ollama
+        options = self.get_ollama_chat_options(request)
+
+        return {
+            "model": self.map_to_provider_model(request.model),
+            "prompt": prompt,
+            "options": options,
+            "raw": True,
+            "stream": request.stream,
+        }
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        params = self._get_params(request)
+        r = await self.client.generate(**params)
+        stop_reason = None
+        if r["done"]:
+            if r["done_reason"] == "stop":
+                stop_reason = StopReason.end_of_turn
+            elif r["done_reason"] == "length":
+                stop_reason = StopReason.out_of_tokens
+
+        completion_message = self.formatter.decode_assistant_message_from_content(
+            r["response"], stop_reason
+        )
+        return ChatCompletionResponse(
+            completion_message=completion_message,
+            logprobs=None,
+        )
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        params = self._get_params(request)
+
+        stream = await self.client.generate(**params)
+
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.start,
+                delta="",
+            )
+        )
+
+        buffer = ""
+        ipython = False
+        stop_reason = None
+
+        async for chunk in stream:
+            if chunk["done"]:
+                if stop_reason is None and chunk["done_reason"] == "stop":
+                    stop_reason = StopReason.end_of_turn
+                elif stop_reason is None and chunk["done_reason"] == "length":
+                    stop_reason = StopReason.out_of_tokens
+                break

+            text = chunk["response"]
+
+            # check if its a tool call ( aka starts with <|python_tag|> )
+            if not ipython and text.startswith("<|python_tag|>"):
+                ipython = True
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            content="",
+                            parse_status=ToolCallParseStatus.started,
+                        ),
+                    )
+                )
+                buffer += text
+                continue
+
+            if ipython:
+                if text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                    text = ""
+                    continue
+                elif text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+                    text = ""
+                    continue
+
+                buffer += text
+                delta = ToolCallDelta(
+                    content=text,
+                    parse_status=ToolCallParseStatus.in_progress,
+                )
+
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=delta,
+                        stop_reason=stop_reason,
+                    )
+                )
+            else:
+                buffer += text
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=text,
+                        stop_reason=stop_reason,
+                    )
+                )
+
+        # parse tool calls and report errors
+        message = self.formatter.decode_assistant_message_from_content(
+            buffer, stop_reason
+        )
+        parsed_tool_calls = len(message.tool_calls) > 0
+        if ipython and not parsed_tool_calls:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        content="",
+                        parse_status=ToolCallParseStatus.failure,
+                    ),
+                    stop_reason=stop_reason,
+                )
+            )
+
+        for tool_call in message.tool_calls:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        content=tool_call,
+                        parse_status=ToolCallParseStatus.success,
+                    ),
+                    stop_reason=stop_reason,
+                )
+            )
+
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.complete,
+                delta="",
+                stop_reason=stop_reason,
+            )
+        )


@@ -46,9 +46,7 @@ class MetaReferenceInferenceImpl(Inference):
     async def shutdown(self) -> None:
         self.generator.stop()

-    # hm, when stream=False, we should not be doing SSE :/ which is what the
-    # top-level server is going to do. make the typing more specific here
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -59,6 +57,9 @@ class MetaReferenceInferenceImpl(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if logprobs:
+            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
+
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
@@ -71,7 +72,6 @@ class MetaReferenceInferenceImpl(Inference):
             logprobs=logprobs,
         )

-        messages = augment_messages_for_tools(request)
         model = resolve_model(request.model)
         if model is None:
             raise RuntimeError(
@@ -87,138 +87,163 @@ class MetaReferenceInferenceImpl(Inference):
         async with SEMAPHORE:
             if request.stream:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.start,
-                        delta="",
-                    )
-                )
-
-            tokens = []
-            logprobs = []
-
-            stop_reason = None
-
-            buffer = ""
-            ipython = False
-
-            for token_result in self.generator.chat_completion(
-                messages=messages,
-                temperature=request.sampling_params.temperature,
-                top_p=request.sampling_params.top_p,
-                max_gen_len=request.sampling_params.max_tokens,
-                logprobs=request.logprobs,
-                tool_prompt_format=request.tool_prompt_format,
-            ):
-                buffer += token_result.text
-                tokens.append(token_result.token)
-
-                if not ipython and buffer.startswith("<|python_tag|>"):
-                    ipython = True
-
-                    if request.stream:
-                        yield ChatCompletionResponseStreamChunk(
-                            event=ChatCompletionResponseEvent(
-                                event_type=ChatCompletionResponseEventType.progress,
-                                delta=ToolCallDelta(
-                                    content="",
-                                    parse_status=ToolCallParseStatus.started,
-                                ),
-                            )
-                        )
-
-                    buffer = buffer[len("<|python_tag|>") :]
-                    continue
-
-                if not request.stream:
-                    if request.logprobs:
-                        assert (
-                            len(token_result.logprobs) == 1
-                        ), "Expected logprob to contain 1 result for the current token"
-                        assert (
-                            request.logprobs.top_k == 1
-                        ), "Only top_k=1 is supported for LogProbConfig"
-
-                        logprobs.append(
-                            TokenLogProbs(
-                                logprobs_by_token={
-                                    token_result.text: token_result.logprobs[0]
-                                }
-                            )
-                        )
-
-                    continue
-
-                if token_result.text == "<|eot_id|>":
-                    stop_reason = StopReason.end_of_turn
-                    text = ""
-                elif token_result.text == "<|eom_id|>":
-                    stop_reason = StopReason.end_of_message
-                    text = ""
-                else:
-                    text = token_result.text
-
-                if ipython:
-                    delta = ToolCallDelta(
-                        content=text,
-                        parse_status=ToolCallParseStatus.in_progress,
-                    )
-                else:
-                    delta = text
-
-                if stop_reason is None:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=delta,
-                            stop_reason=stop_reason,
-                        )
-                    )
-
-            if stop_reason is None:
-                stop_reason = StopReason.out_of_tokens
-
-            # TODO(ashwin): parse tool calls separately here and report errors?
-            # if someone breaks the iteration before coming here we are toast
-            message = self.generator.formatter.decode_assistant_message(
-                tokens, stop_reason
-            )
-            if request.stream:
-                parsed_tool_calls = len(message.tool_calls) > 0
-                if ipython and not parsed_tool_calls:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content="",
-                                parse_status=ToolCallParseStatus.failure,
-                            ),
-                            stop_reason=stop_reason,
-                        )
-                    )
-
-                for tool_call in message.tool_calls:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content=tool_call,
-                                parse_status=ToolCallParseStatus.success,
-                            ),
-                            stop_reason=stop_reason,
-                        )
-                    )
-
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.complete,
-                        delta="",
-                        stop_reason=stop_reason,
-                    )
-                )
-
-                # TODO(ashwin): what else do we need to send out here when everything finishes?
-            else:
-                yield ChatCompletionResponse(
-                    completion_message=message,
-                    logprobs=logprobs if request.logprobs else None,
-                )
+                return self._stream_chat_completion(request)
+            else:
+                return self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        messages = augment_messages_for_tools(request)
+
+        tokens = []
+        logprobs = []
+        stop_reason = None
+
+        for token_result in self.generator.chat_completion(
+            messages=messages,
+            temperature=request.sampling_params.temperature,
+            top_p=request.sampling_params.top_p,
+            max_gen_len=request.sampling_params.max_tokens,
+            logprobs=request.logprobs,
+            tool_prompt_format=request.tool_prompt_format,
+        ):
+            tokens.append(token_result.token)
+
+            if token_result.text == "<|eot_id|>":
+                stop_reason = StopReason.end_of_turn
+            elif token_result.text == "<|eom_id|>":
+                stop_reason = StopReason.end_of_message
+
+            if request.logprobs:
+                assert len(token_result.logprobs) == 1
+
+                logprobs.append(
+                    TokenLogProbs(
+                        logprobs_by_token={token_result.text: token_result.logprobs[0]}
+                    )
+                )
+
+        if stop_reason is None:
+            stop_reason = StopReason.out_of_tokens
+
+        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
+        return ChatCompletionResponse(
+            completion_message=message,
+            logprobs=logprobs if request.logprobs else None,
+        )
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        messages = augment_messages_for_tools(request)
+
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.start,
+                delta="",
+            )
+        )
+
+        tokens = []
+        logprobs = []
+        stop_reason = None
+        ipython = False
+
+        for token_result in self.generator.chat_completion(
+            messages=messages,
+            temperature=request.sampling_params.temperature,
+            top_p=request.sampling_params.top_p,
+            max_gen_len=request.sampling_params.max_tokens,
+            logprobs=request.logprobs,
+            tool_prompt_format=request.tool_prompt_format,
+        ):
+            tokens.append(token_result.token)
+
+            if not ipython and token_result.text.startswith("<|python_tag|>"):
+                ipython = True
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            content="",
+                            parse_status=ToolCallParseStatus.started,
+                        ),
+                    )
+                )
+                continue
+
+            if token_result.text == "<|eot_id|>":
+                stop_reason = StopReason.end_of_turn
+                text = ""
+            elif token_result.text == "<|eom_id|>":
+                stop_reason = StopReason.end_of_message
+                text = ""
+            else:
+                text = token_result.text
+
+            if ipython:
+                delta = ToolCallDelta(
+                    content=text,
+                    parse_status=ToolCallParseStatus.in_progress,
+                )
+            else:
+                delta = text
+
+            if stop_reason is None:
+                if request.logprobs:
+                    assert len(token_result.logprobs) == 1
+
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )
+
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=delta,
+                        stop_reason=stop_reason,
+                        logprobs=logprobs if request.logprobs else None,
+                    )
+                )
+
+        if stop_reason is None:
+            stop_reason = StopReason.out_of_tokens
+
+        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
+
+        parsed_tool_calls = len(message.tool_calls) > 0
+        if ipython and not parsed_tool_calls:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        content="",
+                        parse_status=ToolCallParseStatus.failure,
+                    ),
+                    stop_reason=stop_reason,
+                )
+            )
+
+        for tool_call in message.tool_calls:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        content=tool_call,
+                        parse_status=ToolCallParseStatus.success,
+                    ),
+                    stop_reason=stop_reason,
+                )
+            )
+
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.complete,
+                delta="",
+                stop_reason=stop_reason,
+            )
+        )