diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py
index 8b822058f..c7b865ebf 100644
--- a/llama_stack/apis/inference/client.py
+++ b/llama_stack/apis/inference/client.py
@@ -42,10 +42,10 @@ class InferenceClient(Inference):
     async def shutdown(self) -> None:
         pass
 
-    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    def completion(self, request: CompletionRequest) -> AsyncGenerator:
         raise NotImplementedError()
 
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -66,48 +66,57 @@ class InferenceClient(Inference):
             stream=stream,
             logprobs=logprobs,
         )
+        if stream:
+            return self._stream_chat_completion(request)
+        else:
+            return self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
         async with httpx.AsyncClient() as client:
-            if stream:
-                async with client.stream(
-                    "POST",
-                    f"{self.base_url}/inference/chat_completion",
-                    json=encodable_dict(request),
-                    headers={"Content-Type": "application/json"},
-                    timeout=20,
-                ) as response:
-                    if response.status_code != 200:
-                        content = await response.aread()
-                        cprint(
-                            f"Error: HTTP {response.status_code} {content.decode()}",
-                            "red",
-                        )
-                        return
+            response = await client.post(
+                f"{self.base_url}/inference/chat_completion",
+                json=encodable_dict(request),
+                headers={"Content-Type": "application/json"},
+                timeout=20,
+            )
 
-                    async for line in response.aiter_lines():
-                        if line.startswith("data:"):
-                            data = line[len("data: ") :]
-                            try:
-                                if "error" in data:
-                                    cprint(data, "red")
-                                    continue
+            response.raise_for_status()
+            j = response.json()
+            yield ChatCompletionResponse(**j)
 
-                                yield ChatCompletionResponseStreamChunk(
-                                    **json.loads(data)
-                                )
-                            except Exception as e:
-                                print(data)
-                                print(f"Error with parsing or validation: {e}")
-            else:
-                response = await client.post(
-                    f"{self.base_url}/inference/chat_completion",
-                    json=encodable_dict(request),
-                    headers={"Content-Type": "application/json"},
-                    timeout=20,
-                )
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        async with httpx.AsyncClient() as client:
+            async with client.stream(
+                "POST",
+                f"{self.base_url}/inference/chat_completion",
+                json=encodable_dict(request),
+                headers={"Content-Type": "application/json"},
+                timeout=20,
+            ) as response:
+                if response.status_code != 200:
+                    content = await response.aread()
+                    cprint(
+                        f"Error: HTTP {response.status_code} {content.decode()}",
+                        "red",
+                    )
+                    return
 
-                response.raise_for_status()
-                j = response.json()
-                yield ChatCompletionResponse(**j)
+                async for line in response.aiter_lines():
+                    if line.startswith("data:"):
+                        data = line[len("data: ") :]
+                        try:
+                            if "error" in data:
+                                cprint(data, "red")
+                                continue
+
+                            yield ChatCompletionResponseStreamChunk(**json.loads(data))
+                        except Exception as e:
+                            print(data)
+                            print(f"Error with parsing or validation: {e}")
 
 
 async def run_main(
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 7ff70a2af..13a51bc59 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -180,8 +180,10 @@ class ModelStore(Protocol):
 class Inference(Protocol):
     model_store: ModelStore
 
+    # This method is not `async def` because it can result in either an
+    # `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
     @webmethod(route="/inference/completion")
-    async def completion(
+    def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -190,8 +192,10 @@
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
 
+    # This method is not `async def` because it can result in either an
+    # `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
     @webmethod(route="/inference/chat_completion")
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 361cee3f3..cf62da1d0 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -70,7 +70,7 @@ class InferenceRouter(Inference):
     async def register_model(self, model: ModelDef) -> None:
         await self.routing_table.register_model(model)
 
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -91,27 +91,32 @@ class InferenceRouter(Inference):
             stream=stream,
             logprobs=logprobs,
         )
-        # TODO: we need to fix streaming response to align provider implementations with Protocol.
-        async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
-            **params
-        ):
-            yield chunk
+        provider = self.routing_table.get_provider_impl(model)
+        if stream:
+            return (chunk async for chunk in provider.chat_completion(**params))
+        else:
+            return provider.chat_completion(**params)
 
-    async def completion(
+    def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
-        return await self.routing_table.get_provider_impl(model).completion(
+    ) -> AsyncGenerator:
+        provider = self.routing_table.get_provider_impl(model)
+        params = dict(
             model=model,
             content=content,
             sampling_params=sampling_params,
             stream=stream,
             logprobs=logprobs,
         )
+        if stream:
+            return (chunk async for chunk in provider.completion(**params))
+        else:
+            return provider.completion(**params)
 
     async def embeddings(
         self,
diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index 40a3f5977..80d2ad4c8 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -55,7 +55,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass
 
-    async def completion(
+    def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -79,7 +79,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
 
         return options
 
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -90,24 +90,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        request = ChatCompletionRequest(
-            model=model,
-            messages=messages,
-            sampling_params=sampling_params,
-            tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
-            stream=stream,
-            logprobs=logprobs,
-        )
-
-        messages = augment_messages_for_tools(request)
-        model_input = self.formatter.encode_dialog_prompt(messages)
-        prompt = self.tokenizer.decode(model_input.tokens)
-
-        # accumulate sampling params and other options to pass to ollama
-        options = self.get_ollama_chat_options(request)
-        ollama_model = self.map_to_provider_model(request.model)
+        ollama_model = self.map_to_provider_model(model)
 
         res = await self.client.ps()
         need_model_pull = True
@@ -123,133 +106,166 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
                 status["status"] == "success"
             ), f"Failed to pull model {self.model} in ollama"
 
-        common_params = {
-            "model": ollama_model,
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
+        if stream:
+            return self._stream_chat_completion(request)
+        else:
+            return self._nonstream_chat_completion(request)
+
+    def _get_params(self, request: ChatCompletionRequest) -> dict:
+        messages = augment_messages_for_tools(request)
+        model_input = self.formatter.encode_dialog_prompt(messages)
+        prompt = self.tokenizer.decode(model_input.tokens)
+
+        # accumulate sampling params and other options to pass to ollama
+        options = self.get_ollama_chat_options(request)
+
+        return {
+            "model": self.map_to_provider_model(request.model),
             "prompt": prompt,
             "options": options,
             "raw": True,
             "stream": request.stream,
         }
 
-        if not request.stream:
-            r = await self.client.generate(**common_params)
-            stop_reason = None
-            if r["done"]:
-                if r["done_reason"] == "stop":
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        params = self._get_params(request)
+        r = await self.client.generate(**params)
+        stop_reason = None
+        if r["done"]:
+            if r["done_reason"] == "stop":
+                stop_reason = StopReason.end_of_turn
+            elif r["done_reason"] == "length":
+                stop_reason = StopReason.out_of_tokens
+
+        completion_message = self.formatter.decode_assistant_message_from_content(
+            r["response"], stop_reason
+        )
+        return ChatCompletionResponse(
+            completion_message=completion_message,
+            logprobs=None,
+        )
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        params = self._get_params(request)
+
+        stream = await self.client.generate(**params)
+
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.start,
+                delta="",
+            )
+        )
+
+        buffer = ""
+        ipython = False
+        stop_reason = None
+
+        async for chunk in stream:
+            if chunk["done"]:
+                if stop_reason is None and chunk["done_reason"] == "stop":
                     stop_reason = StopReason.end_of_turn
-                elif r["done_reason"] == "length":
+                elif stop_reason is None and chunk["done_reason"] == "length":
                     stop_reason = StopReason.out_of_tokens
+                break
 
-            completion_message = self.formatter.decode_assistant_message_from_content(
-                r["response"], stop_reason
-            )
-            yield ChatCompletionResponse(
-                completion_message=completion_message,
-                logprobs=None,
-            )
-        else:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.start,
-                    delta="",
-                )
-            )
-            stream = await self.client.generate(**common_params)
-
-            buffer = ""
-            ipython = False
-            stop_reason = None
-
-            async for chunk in stream:
-                if chunk["done"]:
-                    if stop_reason is None and chunk["done_reason"] == "stop":
-                        stop_reason = StopReason.end_of_turn
-                    elif stop_reason is None and chunk["done_reason"] == "length":
-                        stop_reason = StopReason.out_of_tokens
-                    break
-
-                text = chunk["response"]
-                # check if its a tool call ( aka starts with <|python_tag|> )
-                if not ipython and text.startswith("<|python_tag|>"):
-                    ipython = True
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content="",
-                                parse_status=ToolCallParseStatus.started,
-                            ),
-                        )
-                    )
-                    buffer += text
-                    continue
-
-                if ipython:
-                    if text == "<|eot_id|>":
-                        stop_reason = StopReason.end_of_turn
-                        text = ""
-                        continue
-                    elif text == "<|eom_id|>":
-                        stop_reason = StopReason.end_of_message
-                        text = ""
-                        continue
-
-                    buffer += text
-                    delta = ToolCallDelta(
-                        content=text,
-                        parse_status=ToolCallParseStatus.in_progress,
-                    )
-
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=delta,
-                            stop_reason=stop_reason,
-                        )
-                    )
-                else:
-                    buffer += text
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=text,
-                            stop_reason=stop_reason,
-                        )
-                    )
-
-            # parse tool calls and report errors
-            message = self.formatter.decode_assistant_message_from_content(
-                buffer, stop_reason
-            )
-            parsed_tool_calls = len(message.tool_calls) > 0
-            if ipython and not parsed_tool_calls:
+            text = chunk["response"]
+            # check if its a tool call ( aka starts with <|python_tag|> )
+            if not ipython and text.startswith("<|python_tag|>"):
+                ipython = True
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
                         delta=ToolCallDelta(
                             content="",
-                            parse_status=ToolCallParseStatus.failure,
+                            parse_status=ToolCallParseStatus.started,
                         ),
-                        stop_reason=stop_reason,
                     )
                 )
+                buffer += text
+                continue
+
+            if ipython:
+                if text == "<|eot_id|>":
+                    stop_reason = StopReason.end_of_turn
+                    text = ""
+                    continue
+                elif text == "<|eom_id|>":
+                    stop_reason = StopReason.end_of_message
+                    text = ""
+                    continue
+
+                buffer += text
+                delta = ToolCallDelta(
+                    content=text,
+                    parse_status=ToolCallParseStatus.in_progress,
+                )
 
-            for tool_call in message.tool_calls:
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content=tool_call,
-                            parse_status=ToolCallParseStatus.success,
-                        ),
+                        delta=delta,
+                        stop_reason=stop_reason,
+                    )
+                )
+            else:
+                buffer += text
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=text,
                         stop_reason=stop_reason,
                     )
                 )
 
+        # parse tool calls and report errors
+        message = self.formatter.decode_assistant_message_from_content(
+            buffer, stop_reason
+        )
+        parsed_tool_calls = len(message.tool_calls) > 0
+        if ipython and not parsed_tool_calls:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.complete,
-                    delta="",
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        content="",
+                        parse_status=ToolCallParseStatus.failure,
+                    ),
                     stop_reason=stop_reason,
                 )
             )
+
+        for tool_call in message.tool_calls:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        content=tool_call,
+                        parse_status=ToolCallParseStatus.success,
+                    ),
+                    stop_reason=stop_reason,
+                )
+            )
+
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.complete,
+                delta="",
+                stop_reason=stop_reason,
+            )
+        )
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index a310a479a..ad8cc31fd 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -46,9 +46,7 @@ class MetaReferenceInferenceImpl(Inference):
     async def shutdown(self) -> None:
         self.generator.stop()
 
-    # hm, when stream=False, we should not be doing SSE :/ which is what the
-    # top-level server is going to do. make the typing more specific here
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -59,6 +57,9 @@ class MetaReferenceInferenceImpl(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        if logprobs:
+            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
+
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
@@ -71,7 +72,6 @@ class MetaReferenceInferenceImpl(Inference):
             logprobs=logprobs,
         )
 
-        messages = augment_messages_for_tools(request)
         model = resolve_model(request.model)
         if model is None:
             raise RuntimeError(
@@ -87,138 +87,163 @@ class MetaReferenceInferenceImpl(Inference):
 
         async with SEMAPHORE:
             if request.stream:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.start,
-                        delta="",
+                return self._stream_chat_completion(request)
+            else:
+                return self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        messages = augment_messages_for_tools(request)
+
+        tokens = []
+        logprobs = []
+        stop_reason = None
+
+        for token_result in self.generator.chat_completion(
+            messages=messages,
+            temperature=request.sampling_params.temperature,
+            top_p=request.sampling_params.top_p,
+            max_gen_len=request.sampling_params.max_tokens,
+            logprobs=request.logprobs,
+            tool_prompt_format=request.tool_prompt_format,
+        ):
+            tokens.append(token_result.token)
+
+            if token_result.text == "<|eot_id|>":
+                stop_reason = StopReason.end_of_turn
+            elif token_result.text == "<|eom_id|>":
+                stop_reason = StopReason.end_of_message
+
+            if request.logprobs:
+                assert len(token_result.logprobs) == 1
+
+                logprobs.append(
+                    TokenLogProbs(
+                        logprobs_by_token={token_result.text: token_result.logprobs[0]}
                     )
                 )
 
-            tokens = []
-            logprobs = []
+        if stop_reason is None:
+            stop_reason = StopReason.out_of_tokens
 
-            stop_reason = None
+        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
+        return ChatCompletionResponse(
+            completion_message=message,
+            logprobs=logprobs if request.logprobs else None,
+        )
 
-            buffer = ""
-            ipython = False
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        messages = augment_messages_for_tools(request)
 
-            for token_result in self.generator.chat_completion(
-                messages=messages,
-                temperature=request.sampling_params.temperature,
-                top_p=request.sampling_params.top_p,
-                max_gen_len=request.sampling_params.max_tokens,
-                logprobs=request.logprobs,
-                tool_prompt_format=request.tool_prompt_format,
-            ):
-                buffer += token_result.text
-                tokens.append(token_result.token)
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.start,
+                delta="",
+            )
+        )
 
-                if not ipython and buffer.startswith("<|python_tag|>"):
-                    ipython = True
-                    if request.stream:
-                        yield ChatCompletionResponseStreamChunk(
-                            event=ChatCompletionResponseEvent(
-                                event_type=ChatCompletionResponseEventType.progress,
-                                delta=ToolCallDelta(
-                                    content="",
-                                    parse_status=ToolCallParseStatus.started,
-                                ),
-                            )
-                        )
+        tokens = []
+        logprobs = []
+        stop_reason = None
+        ipython = False
 
-                    buffer = buffer[len("<|python_tag|>") :]
-                    continue
+        for token_result in self.generator.chat_completion(
+            messages=messages,
+            temperature=request.sampling_params.temperature,
+            top_p=request.sampling_params.top_p,
+            max_gen_len=request.sampling_params.max_tokens,
+            logprobs=request.logprobs,
+            tool_prompt_format=request.tool_prompt_format,
+        ):
+            tokens.append(token_result.token)
 
-                if not request.stream:
-                    if request.logprobs:
-                        assert (
-                            len(token_result.logprobs) == 1
-                        ), "Expected logprob to contain 1 result for the current token"
-                        assert (
-                            request.logprobs.top_k == 1
-                        ), "Only top_k=1 is supported for LogProbConfig"
-
-                        logprobs.append(
-                            TokenLogProbs(
-                                logprobs_by_token={
-                                    token_result.text: token_result.logprobs[0]
-                                }
-                            )
-                        )
-
-                    continue
-
-                if token_result.text == "<|eot_id|>":
-                    stop_reason = StopReason.end_of_turn
-                    text = ""
-                elif token_result.text == "<|eom_id|>":
-                    stop_reason = StopReason.end_of_message
-                    text = ""
-                else:
-                    text = token_result.text
-
-                if ipython:
-                    delta = ToolCallDelta(
-                        content=text,
-                        parse_status=ToolCallParseStatus.in_progress,
+            if not ipython and token_result.text.startswith("<|python_tag|>"):
+                ipython = True
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            content="",
+                            parse_status=ToolCallParseStatus.started,
+                        ),
                     )
-                else:
-                    delta = text
+                )
+                continue
 
-                if stop_reason is None:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=delta,
-                            stop_reason=stop_reason,
-                        )
-                    )
+            if token_result.text == "<|eot_id|>":
+                stop_reason = StopReason.end_of_turn
+                text = ""
+            elif token_result.text == "<|eom_id|>":
+                stop_reason = StopReason.end_of_message
+                text = ""
+            else:
+                text = token_result.text
+
+            if ipython:
+                delta = ToolCallDelta(
+                    content=text,
+                    parse_status=ToolCallParseStatus.in_progress,
+                )
+            else:
+                delta = text
 
             if stop_reason is None:
-                stop_reason = StopReason.out_of_tokens
+                if request.logprobs:
+                    assert len(token_result.logprobs) == 1
 
-            # TODO(ashwin): parse tool calls separately here and report errors?
-            # if someone breaks the iteration before coming here we are toast
-            message = self.generator.formatter.decode_assistant_message(
-                tokens, stop_reason
-            )
-            if request.stream:
-                parsed_tool_calls = len(message.tool_calls) > 0
-                if ipython and not parsed_tool_calls:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content="",
-                                parse_status=ToolCallParseStatus.failure,
-                            ),
-                            stop_reason=stop_reason,
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
                         )
                     )
-
-                for tool_call in message.tool_calls:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content=tool_call,
-                                parse_status=ToolCallParseStatus.success,
-                            ),
-                            stop_reason=stop_reason,
-                        )
-                    )
-
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.complete,
-                        delta="",
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=delta,
                         stop_reason=stop_reason,
+                        logprobs=logprobs if request.logprobs else None,
                     )
                 )
-            # TODO(ashwin): what else do we need to send out here when everything finishes?
-            else:
-                yield ChatCompletionResponse(
-                    completion_message=message,
-                    logprobs=logprobs if request.logprobs else None,
+        if stop_reason is None:
+            stop_reason = StopReason.out_of_tokens
+
+        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
+
+        parsed_tool_calls = len(message.tool_calls) > 0
+        if ipython and not parsed_tool_calls:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        content="",
+                        parse_status=ToolCallParseStatus.failure,
+                    ),
+                    stop_reason=stop_reason,
                 )
+            )
+
+        for tool_call in message.tool_calls:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=ToolCallDelta(
+                        content=tool_call,
+                        parse_status=ToolCallParseStatus.success,
+                    ),
+                    stop_reason=stop_reason,
+                )
+            )
+
+        yield ChatCompletionResponseStreamChunk(
+            event=ChatCompletionResponseEvent(
+                event_type=ChatCompletionResponseEventType.complete,
+                delta="",
+                stop_reason=stop_reason,
+            )
+        )