Separate chat_completion stream and non-stream implementations

This is an important requirement. The streaming response type is an
AsyncGenerator, while the non-stream one is a single response object. So
far this has only worked _sometimes_, thanks to various pre-existing hacks
(and in some cases it simply failed).
Ashwin Bharambe authored and committed on 2024-10-08 10:52:16 -07:00
parent f8752ab8dc · commit 0c9eb3341c
5 changed files with 346 additions and 287 deletions
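To make the pattern concrete, here is a minimal, self-contained sketch of the shape this commit moves every implementation toward. It is illustrative only (hypothetical names, simplified types), not code from the diff: the public method is a plain `def` that dispatches to two private helpers, a coroutine for the non-stream case and an async generator for the stream case.

from typing import AsyncGenerator, Dict, List


class ExampleInference:
    def chat_completion(self, messages: List[str], stream: bool = False):
        # Plain `def`: the caller gets a coroutine to await when stream=False,
        # and an async generator to iterate with `async for` when stream=True.
        if stream:
            return self._stream_chat_completion(messages)
        return self._nonstream_chat_completion(messages)

    async def _nonstream_chat_completion(self, messages: List[str]) -> Dict:
        # One request, one response object.
        return {"completion": "hello", "stop_reason": "end_of_turn"}

    async def _stream_chat_completion(self, messages: List[str]) -> AsyncGenerator:
        # One request, many chunks yielded as they are produced.
        for piece in ("hel", "lo"):
            yield {"delta": piece}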


@@ -42,10 +42,10 @@ class InferenceClient(Inference):
async def shutdown(self) -> None:
pass
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(
def chat_completion(
self,
model: str,
messages: List[Message],
@@ -66,8 +66,30 @@ class InferenceClient(Inference):
stream=stream,
logprobs=logprobs,
)
async with httpx.AsyncClient() as client:
if stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/inference/chat_completion",
json=encodable_dict(request),
headers={"Content-Type": "application/json"},
timeout=20,
)
response.raise_for_status()
j = response.json()
yield ChatCompletionResponse(**j)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
f"{self.base_url}/inference/chat_completion",
@@ -91,23 +113,10 @@ class InferenceClient(Inference):
cprint(data, "red")
continue
yield ChatCompletionResponseStreamChunk(
**json.loads(data)
)
yield ChatCompletionResponseStreamChunk(**json.loads(data))
except Exception as e:
print(data)
print(f"Error with parsing or validation: {e}")
else:
response = await client.post(
f"{self.base_url}/inference/chat_completion",
json=encodable_dict(request),
headers={"Content-Type": "application/json"},
timeout=20,
)
response.raise_for_status()
j = response.json()
yield ChatCompletionResponse(**j)
async def run_main(
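A hedged usage sketch for the reworked client above (not part of the diff): the caller chooses between `await` and `async for` based on `stream`. It assumes the non-stream path resolves to a single ChatCompletionResponse, as the protocol comments in this commit describe; the model name is a placeholder.

async def print_chat(client, messages) -> None:
    # Non-streaming: chat_completion returns a coroutine resolving to a
    # single ChatCompletionResponse.
    response = await client.chat_completion(
        model="example-model", messages=messages, stream=False
    )
    print(response.completion_message)

    # Streaming: chat_completion returns an AsyncGenerator of
    # ChatCompletionResponseStreamChunk objects.
    async for chunk in client.chat_completion(
        model="example-model", messages=messages, stream=True
    ):
        print(chunk.event.delta, end="", flush=True)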


@@ -180,8 +180,10 @@ class ModelStore(Protocol):
class Inference(Protocol):
model_store: ModelStore
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/completion")
async def completion(
def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -190,8 +192,10 @@ class Inference(Protocol):
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/chat_completion")
async def chat_completion(
def chat_completion(
self,
model: str,
messages: List[Message],
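For reference, a compact, hypothetical version of the protocol convention introduced above: the method is deliberately declared as a plain `def`, because an `async def` signature would commit implementations to returning a coroutine, whereas streaming implementations hand back an AsyncGenerator object directly. Class and field names below are stand-ins, not the project's types.

from typing import AsyncGenerator, List, Optional, Protocol, Union


class ExampleResponse:
    """Stand-in for a single (non-stream) chat completion response."""


class ExampleStreamChunk:
    """Stand-in for one streamed chunk of a chat completion."""


class ExampleInferenceProtocol(Protocol):
    # Not `async def`: depending on `stream`, the call site either awaits a
    # single ExampleResponse or iterates an AsyncGenerator of chunks.
    def chat_completion(
        self,
        model: str,
        messages: List[str],
        stream: Optional[bool] = False,
    ) -> Union[ExampleResponse, AsyncGenerator[ExampleStreamChunk, None]]: ...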


@@ -70,7 +70,7 @@ class InferenceRouter(Inference):
async def register_model(self, model: ModelDef) -> None:
await self.routing_table.register_model(model)
async def chat_completion(
def chat_completion(
self,
model: str,
messages: List[Message],
@@ -91,27 +91,32 @@ class InferenceRouter(Inference):
stream=stream,
logprobs=logprobs,
)
# TODO: we need to fix streaming response to align provider implementations with Protocol.
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
**params
):
yield chunk
provider = self.routing_table.get_provider_impl(model)
if stream:
return (chunk async for chunk in provider.chat_completion(**params))
else:
return provider.chat_completion(**params)
async def completion(
def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
return await self.routing_table.get_provider_impl(model).completion(
) -> AsyncGenerator:
provider = self.routing_table.get_provider_impl(model)
params = dict(
model=model,
content=content,
sampling_params=sampling_params,
stream=stream,
logprobs=logprobs,
)
if stream:
return (chunk async for chunk in provider.completion(**params))
else:
return provider.completion(**params)
async def embeddings(
self,
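A hedged sketch of the router-side delegation above (illustrative names, simplified parameters): the router's method is also a plain `def`; for streaming it wraps the provider's async generator in an async generator expression so its own return value is again an AsyncGenerator, and for the non-stream case it simply passes the provider's coroutine through.

class ExampleRouter:
    def __init__(self, providers: dict):
        # Maps a model name to the provider implementation serving it.
        self.providers = providers

    def chat_completion(self, model: str, messages, stream: bool = False):
        provider = self.providers[model]
        params = dict(model=model, messages=messages, stream=stream)
        if stream:
            # Re-yield the provider's chunks; this expression builds a new
            # async generator whose body runs only when the caller iterates.
            return (chunk async for chunk in provider.chat_completion(**params))
        # Non-stream: hand the provider's coroutine to the caller to await.
        return provider.chat_completion(**params)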


@@ -55,7 +55,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
async def shutdown(self) -> None:
pass
async def completion(
def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -79,7 +79,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
return options
async def chat_completion(
def chat_completion(
self,
model: str,
messages: List[Message],
@@ -90,24 +90,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = augment_messages_for_tools(request)
model_input = self.formatter.encode_dialog_prompt(messages)
prompt = self.tokenizer.decode(model_input.tokens)
# accumulate sampling params and other options to pass to ollama
options = self.get_ollama_chat_options(request)
ollama_model = self.map_to_provider_model(request.model)
ollama_model = self.map_to_provider_model(model)
res = await self.client.ps()
need_model_pull = True
@@ -123,16 +106,43 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
status["status"] == "success"
), f"Failed to pull model {self.model} in ollama"
common_params = {
"model": ollama_model,
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
def _get_params(self, request: ChatCompletionRequest) -> dict:
messages = augment_messages_for_tools(request)
model_input = self.formatter.encode_dialog_prompt(messages)
prompt = self.tokenizer.decode(model_input.tokens)
# accumulate sampling params and other options to pass to ollama
options = self.get_ollama_chat_options(request)
return {
"model": self.map_to_provider_model(request.model),
"prompt": prompt,
"options": options,
"raw": True,
"stream": request.stream,
}
if not request.stream:
r = await self.client.generate(**common_params)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
params = self._get_params(request)
r = await self.client.generate(**params)
stop_reason = None
if r["done"]:
if r["done_reason"] == "stop":
@@ -143,18 +153,24 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
completion_message = self.formatter.decode_assistant_message_from_content(
r["response"], stop_reason
)
yield ChatCompletionResponse(
return ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
else:
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
params = self._get_params(request)
stream = await self.client.generate(**params)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
stream = await self.client.generate(**common_params)
buffer = ""
ipython = False
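As a hedged sketch of the adapter-side refactor above: request construction moves into a shared _get_params helper so the stream and non-stream paths build identical backend calls. The backend object and its generate() coroutine here are stand-ins, not the real ollama client API.

from typing import AsyncGenerator


class ExampleAdapter:
    def __init__(self, backend):
        self.backend = backend  # stand-in for an async inference backend

    def _get_params(self, request: dict) -> dict:
        # Single place where prompt formatting and option mapping happen.
        return {
            "model": request["model"],
            "prompt": request["prompt"],
            "stream": request["stream"],
        }

    async def _nonstream_chat_completion(self, request: dict) -> dict:
        r = await self.backend.generate(**self._get_params(request))
        return {"completion": r["response"]}

    async def _stream_chat_completion(self, request: dict) -> AsyncGenerator:
        stream = await self.backend.generate(**self._get_params(request))
        async for chunk in stream:
            yield {"delta": chunk["response"]}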


@@ -46,9 +46,7 @@ class MetaReferenceInferenceImpl(Inference):
async def shutdown(self) -> None:
self.generator.stop()
# hm, when stream=False, we should not be doing SSE :/ which is what the
# top-level server is going to do. make the typing more specific here
async def chat_completion(
def chat_completion(
self,
model: str,
messages: List[Message],
@@ -59,6 +57,9 @@ class MetaReferenceInferenceImpl(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model,
@@ -71,7 +72,6 @@ class MetaReferenceInferenceImpl(Inference):
logprobs=logprobs,
)
messages = augment_messages_for_tools(request)
model = resolve_model(request.model)
if model is None:
raise RuntimeError(
@@ -87,6 +87,57 @@ class MetaReferenceInferenceImpl(Inference):
async with SEMAPHORE:
if request.stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
messages = augment_messages_for_tools(request)
tokens = []
logprobs = []
stop_reason = None
for token_result in self.generator.chat_completion(
messages=messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
tokens.append(token_result.token)
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={token_result.text: token_result.logprobs[0]}
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
return ChatCompletionResponse(
completion_message=message,
logprobs=logprobs if request.logprobs else None,
)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
messages = augment_messages_for_tools(request)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
@@ -96,10 +147,7 @@ class MetaReferenceInferenceImpl(Inference):
tokens = []
logprobs = []
stop_reason = None
buffer = ""
ipython = False
for token_result in self.generator.chat_completion(
@@ -110,12 +158,10 @@ class MetaReferenceInferenceImpl(Inference):
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
buffer += token_result.text
tokens.append(token_result.token)
if not ipython and buffer.startswith("<|python_tag|>"):
if not ipython and token_result.text.startswith("<|python_tag|>"):
ipython = True
if request.stream:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
@@ -125,27 +171,6 @@ class MetaReferenceInferenceImpl(Inference):
),
)
)
buffer = buffer[len("<|python_tag|>") :]
continue
if not request.stream:
if request.logprobs:
assert (
len(token_result.logprobs) == 1
), "Expected logprob to contain 1 result for the current token"
assert (
request.logprobs.top_k == 1
), "Only top_k=1 is supported for LogProbConfig"
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
continue
if token_result.text == "<|eot_id|>":
@@ -166,23 +191,30 @@ class MetaReferenceInferenceImpl(Inference):
delta = text
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
# TODO(ashwin): parse tool calls separately here and report errors?
# if someone breaks the iteration before coming here we are toast
message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
if request.stream:
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
@@ -215,10 +247,3 @@ class MetaReferenceInferenceImpl(Inference):
stop_reason=stop_reason,
)
)
# TODO(ashwin): what else do we need to send out here when everything finishes?
else:
yield ChatCompletionResponse(
completion_message=message,
logprobs=logprobs if request.logprobs else None,
)
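The comment removed above ("hm, when stream=False, we should not be doing SSE") points at the server-side consequence of this split. Below is a hedged sketch of how a serving layer could tell the two shapes apart at runtime; this is illustrative only, not the project's server code, and emit_sse/emit_json are hypothetical helpers.

import inspect
import json


def emit_sse(chunk) -> None:
    # Stand-in for writing one server-sent event to the HTTP response.
    print(f"data: {json.dumps(chunk)}\n")


def emit_json(response) -> None:
    # Stand-in for writing a single JSON response body.
    print(json.dumps(response))


async def serve_result(result) -> None:
    # `result` is whatever chat_completion() returned: an async generator
    # when stream=True, a coroutine when stream=False.
    if inspect.isasyncgen(result):
        async for chunk in result:
            emit_sse(chunk)
    else:
        emit_json(await result)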