From 627edaf407fa253102f2e64cba5157f26bfd362c Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Fri, 18 Oct 2024 16:50:57 -0700
Subject: [PATCH] Make all API methods `async def` again

---
 llama_stack/apis/inference/client.py          |  7 +-
 llama_stack/apis/inference/inference.py       |  6 +-
 llama_stack/distribution/routers/routers.py   | 12 ++--
 .../adapters/inference/bedrock/bedrock.py     |  4 +-
 .../inference/databricks/databricks.py        |  6 +-
 .../adapters/inference/fireworks/fireworks.py |  6 +-
 .../adapters/inference/ollama/ollama.py       |  6 +-
 .../providers/adapters/inference/tgi/tgi.py   |  6 +-
 .../adapters/inference/together/together.py   |  4 +-
 .../meta_reference/agents/agent_instance.py   |  2 +-
 .../meta_reference/inference/generation.py    | 49 ++++++--------
 .../meta_reference/inference/inference.py     | 65 ++++++++-----------
 .../inference/model_parallel.py               | 50 +++++++-------
 .../inference/parallel_utils.py               | 30 ++++-----
 .../meta_reference/safety/llama_guard.py      |  2 +-
 llama_stack/providers/impls/vllm/vllm.py      |  6 +-
 .../tests/inference/test_inference.py         |  4 +-
 17 files changed, 120 insertions(+), 145 deletions(-)

diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py
index 79d2cc02c..90636fa36 100644
--- a/llama_stack/apis/inference/client.py
+++ b/llama_stack/apis/inference/client.py
@@ -42,10 +42,10 @@ class InferenceClient(Inference):
     async def shutdown(self) -> None:
         pass
 
-    def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
         raise NotImplementedError()
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -139,7 +139,8 @@ async def run_main(
     else:
         logprobs_config = None
 
-    iterator = client.chat_completion(
+    assert stream, "Non streaming not supported here"
+    iterator = await client.chat_completion(
         model=model,
         messages=[message],
         stream=stream,
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 588dd37ca..449fff50d 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -181,10 +181,8 @@ class ModelStore(Protocol):
 class Inference(Protocol):
     model_store: ModelStore
 
-    # This method is not `async def` because it can result in either an
-    # `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
     @webmethod(route="/inference/completion")
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -196,7 +194,7 @@
     # This method is not `async def` because it can result in either an
     # `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
     @webmethod(route="/inference/chat_completion")
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index cf62da1d0..a78e808d0 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -70,7 +70,7 @@ class InferenceRouter(Inference):
     async def register_model(self, model: ModelDef) -> None:
         await self.routing_table.register_model(model)
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -93,11 +93,11 @@ class InferenceRouter(Inference):
         )
         provider = self.routing_table.get_provider_impl(model)
         if stream:
-            return (chunk async for chunk in provider.chat_completion(**params))
+            return (chunk async for chunk in await provider.chat_completion(**params))
         else:
-            return provider.chat_completion(**params)
+            return await provider.chat_completion(**params)
 
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -114,9 +114,9 @@ class InferenceRouter(Inference):
             logprobs=logprobs,
         )
         if stream:
-            return (chunk async for chunk in provider.completion(**params))
+            return (chunk async for chunk in await provider.completion(**params))
         else:
-            return provider.completion(**params)
+            return await provider.completion(**params)
 
     async def embeddings(
         self,
diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py
index 22f87ef6b..8440ecc20 100644
--- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py
@@ -47,7 +47,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         self.client.close()
 
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -283,7 +283,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         )
         return tool_config
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
diff --git a/llama_stack/providers/adapters/inference/databricks/databricks.py b/llama_stack/providers/adapters/inference/databricks/databricks.py
index 141051186..9f50ad227 100644
--- a/llama_stack/providers/adapters/inference/databricks/databricks.py
+++ b/llama_stack/providers/adapters/inference/databricks/databricks.py
@@ -48,7 +48,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass
 
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -58,7 +58,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -84,7 +84,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         if stream:
             return self._stream_chat_completion(request, client)
         else:
-            return self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request, client)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: OpenAI
diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
index c82012cba..537f3a6b4 100644
--- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
@@ -51,7 +51,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass
 
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -61,7 +61,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -87,7 +87,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         if stream:
             return self._stream_chat_completion(request, client)
         else:
-            return self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request, client)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: Fireworks
diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index c50c869fd..3a3e4b451 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -84,7 +84,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
 
         return ret
 
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -94,7 +94,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     ) -> AsyncGenerator:
         raise NotImplementedError()
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -118,7 +118,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         if stream:
             return self._stream_chat_completion(request)
         else:
-            return self._nonstream_chat_completion(request)
+            return await self._nonstream_chat_completion(request)
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index cd0afad0c..3c610099c 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -66,7 +66,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass
 
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -76,7 +76,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     ) -> AsyncGenerator:
         raise NotImplementedError()
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -101,7 +101,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         if stream:
             return self._stream_chat_completion(request)
         else:
-            return self._nonstream_chat_completion(request)
+            return await self._nonstream_chat_completion(request)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index 750ca126e..8c73d75ec 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -64,7 +64,7 @@ class TogetherInferenceAdapter(
    ) -> AsyncGenerator:
        raise NotImplementedError()
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -101,7 +101,7 @@ class TogetherInferenceAdapter(
         if stream:
             return self._stream_chat_completion(request, client)
         else:
-            return self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request, client)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: Together
diff --git a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
index 0d334fdad..cbc7490fd 100644
--- a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
+++ b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
@@ -424,7 +424,7 @@ class ChatAgent(ShieldRunnerMixin):
         stop_reason = None
 
         with tracing.span("inference"):
-            async for chunk in self.inference_api.chat_completion(
+            async for chunk in await self.inference_api.chat_completion(
                 self.agent_config.model,
                 input_messages,
                 tools=self._get_tools(),
diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py
index 20a8addc7..46a409ebe 100644
--- a/llama_stack/providers/impls/meta_reference/inference/generation.py
+++ b/llama_stack/providers/impls/meta_reference/inference/generation.py
@@ -23,11 +23,6 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
-from llama_models.llama3.api.datatypes import (
-    InterleavedTextMedia,
-    Message,
-    ToolPromptFormat,
-)
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
@@ -38,7 +33,11 @@ from llama_models.sku_list import resolve_model
 
 from pydantic import BaseModel
 from termcolor import cprint
 
+from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_messages,
+)
 
 from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
@@ -297,15 +296,11 @@ class Llama:
             if all(eos_reached):
                 break
 
-    def text_completion(
+    def completion(
         self,
-        content: InterleavedTextMedia,
-        temperature: float = 0.6,
-        top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
-        logprobs: bool = False,
-        echo: bool = False,
+        request: CompletionRequest,
     ) -> Generator:
+        sampling_params = request.sampling_params
         if (
             max_gen_len is None
             or max_gen_len == 0
         ):
             max_gen_len = self.model.params.max_seq_len - 1
 
-        model_input = self.formatter.encode_content(content)
-
+        model_input = self.formatter.encode_content(request.content)
         yield from self.generate(
             model_input=model_input,
             max_gen_len=max_gen_len,
-            temperature=temperature,
-            top_p=top_p,
-            logprobs=logprobs,
-            echo=echo,
+            temperature=sampling_params.temperature,
+            top_p=sampling_params.top_p,
+            logprobs=bool(request.logprobs),
+            echo=False,
         )
 
     def chat_completion(
         self,
-        messages: List[Message],
-        temperature: float = 0.6,
-        top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
-        logprobs: bool = False,
-        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
+        request: ChatCompletionRequest,
     ) -> Generator:
+        messages = chat_completion_request_to_messages(request)
+
+        sampling_params = request.sampling_params
+        max_gen_len = sampling_params.max_tokens
         if (
             max_gen_len is None
             or max_gen_len == 0
         ):
             max_gen_len = self.model.params.max_seq_len - 1
 
         yield from self.generate(
             model_input=self.formatter.encode_dialog_prompt(
                 messages,
-                tool_prompt_format,
+                request.tool_prompt_format,
             ),
             max_gen_len=max_gen_len,
-            temperature=temperature,
-            top_p=top_p,
-            logprobs=logprobs,
+            temperature=sampling_params.temperature,
+            top_p=sampling_params.top_p,
+            logprobs=bool(request.logprobs),
             include_stop_token=True,
         )
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index 7edc279d0..bdccb4f03 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -13,9 +13,6 @@ from llama_models.sku_list import resolve_model
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_messages,
-)
 
 from .config import MetaReferenceInferenceConfig
 from .generation import Llama
@@ -58,7 +55,18 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         if self.config.create_distributed_process_group:
             self.generator.stop()
 
-    def completion(
+    def check_model(self, request) -> None:
+        model = resolve_model(request.model)
+        if model is None:
+            raise RuntimeError(
+                f"Unknown model: {request.model}, Run `llama model list`"
+            )
+        elif model.descriptor() != self.model.descriptor():
+            raise RuntimeError(
+                f"Model mismatch: {request.model} != {self.model.descriptor()}"
+            )
+
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -66,9 +74,19 @@
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
-        raise NotImplementedError()
+        if logprobs:
+            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
 
-    def chat_completion(
+        request = CompletionRequest(
+            model=model,
+            content=content,
+            sampling_params=sampling_params,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        self.check_model(request)
+
+    async def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -93,16 +111,7 @@
             stream=stream,
             logprobs=logprobs,
         )
-
-        model = resolve_model(request.model)
-        if model is None:
-            raise RuntimeError(
-                f"Unknown model: {request.model}, Run `llama model list`"
-            )
-        elif model.descriptor() != self.model.descriptor():
-            raise RuntimeError(
-                f"Model mismatch: {request.model} != {self.model.descriptor()}"
-            )
+        self.check_model(request)
 
         if self.config.create_distributed_process_group:
             if SEMAPHORE.locked():
 
         if request.stream:
             return self._stream_chat_completion(request)
         else:
-            return self._nonstream_chat_completion(request)
+            return await self._nonstream_chat_completion(request)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         def impl():
-            messages = chat_completion_request_to_messages(request)
-
             tokens = []
             logprobs = []
             stop_reason = None
 
-            for token_result in self.generator.chat_completion(
-                messages=messages,
-                temperature=request.sampling_params.temperature,
-                top_p=request.sampling_params.top_p,
-                max_gen_len=request.sampling_params.max_tokens,
-                logprobs=request.logprobs,
-                tool_prompt_format=request.tool_prompt_format,
-            ):
+            for token_result in self.generator.chat_completion(request):
                 tokens.append(token_result.token)
 
                 if token_result.text == "<|eot_id|>":
@@ -170,8 +170,6 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
         def impl():
-            messages = chat_completion_request_to_messages(request)
-
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=ChatCompletionResponseEventType.start,
@@ -184,14 +182,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             stop_reason = None
             ipython = False
 
-            for token_result in self.generator.chat_completion(
-                messages=messages,
-                temperature=request.sampling_params.temperature,
-                top_p=request.sampling_params.top_p,
-                max_gen_len=request.sampling_params.max_tokens,
-                logprobs=request.logprobs,
-                tool_prompt_format=request.tool_prompt_format,
-            ):
+            for token_result in self.generator.chat_completion(request):
                 tokens.append(token_result.token)
 
                 if not ipython and token_result.text.startswith("<|python_tag|>"):
diff --git a/llama_stack/providers/impls/meta_reference/inference/model_parallel.py b/llama_stack/providers/impls/meta_reference/inference/model_parallel.py
index e8f483f30..7e7831185 100644
--- a/llama_stack/providers/impls/meta_reference/inference/model_parallel.py
+++ b/llama_stack/providers/impls/meta_reference/inference/model_parallel.py
@@ -7,16 +7,17 @@
 import os
 from copy import deepcopy
 from functools import partial
-from typing import Generator, List, Optional
+from typing import Any, Generator
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 
+from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
+
 from .config import MetaReferenceInferenceConfig
 from .generation import Llama, model_checkpoint_dir
-from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
+from .parallel_utils import ModelParallelProcessGroup
 
 
 class ModelRunner:
     def __init__(self, llama):
         self.llama = llama
 
     # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
-    def __call__(self, task: InferenceArgs):
-        return self.llama.chat_completion(
-            task.messages,
-            task.temperature,
-            task.top_p,
-            task.max_gen_len,
-            task.logprobs,
-            task.tool_prompt_format,
-        )
+    def __call__(self, req: Any):
+        if isinstance(req, ChatCompletionRequest):
+            return self.llama.chat_completion(req)
+        elif isinstance(req, CompletionRequest):
+            return self.llama.completion(req)
+        else:
+            raise ValueError(f"Unexpected task type {type(req)}")
 
 
 def init_model_cb(config: MetaReferenceInferenceConfig):
@@ -77,23 +76,18 @@ class LlamaModelParallelGenerator:
     def __exit__(self, exc_type, exc_value, exc_traceback):
         self.group.stop()
 
-    def chat_completion(
+    def completion(
         self,
-        messages: List[Message],
-        temperature: float = 0.6,
-        top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
-        logprobs: bool = False,
-        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
+        request: CompletionRequest,
     ) -> Generator:
-        req_obj = InferenceArgs(
-            messages=deepcopy(messages),
-            temperature=temperature,
-            top_p=top_p,
-            max_gen_len=max_gen_len,
-            logprobs=logprobs or False,
-            tool_prompt_format=tool_prompt_format,
-        )
-
+        req_obj = deepcopy(request)
+        gen = self.group.run_inference(req_obj)
+        yield from gen
+
+    def chat_completion(
+        self,
+        request: ChatCompletionRequest,
+    ) -> Generator:
+        req_obj = deepcopy(request)
         gen = self.group.run_inference(req_obj)
         yield from gen
diff --git a/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py b/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py
index 7dbedd0f0..62eeefaac 100644
--- a/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py
+++ b/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py
@@ -4,6 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
 import json
 import multiprocessing
 import os
@@ -11,10 +17,9 @@ import tempfile
 import time
 import uuid
 from enum import Enum
-from typing import Callable, Generator, List, Literal, Optional, Union
+from typing import Callable, Generator, Literal, Optional, Union
 
 import torch
-
 import zmq
 
 from fairscale.nn.model_parallel.initialize import (
@@ -23,25 +28,16 @@ from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_src_rank,
 )
 
-from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
-
 from pydantic import BaseModel, Field
 
 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
 from typing_extensions import Annotated
 
+from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
+
 from .generation import TokenResult
 
 
-class InferenceArgs(BaseModel):
-    messages: List[Message]
-    temperature: float
-    top_p: float
-    max_gen_len: int
-    logprobs: bool
-    tool_prompt_format: ToolPromptFormat
-
-
 class ProcessingMessageName(str, Enum):
     ready_request = "ready_request"
     ready_response = "ready_response"
@@ -80,7 +76,7 @@ class TaskRequest(BaseModel):
     type: Literal[ProcessingMessageName.task_request] = (
         ProcessingMessageName.task_request
     )
-    task: InferenceArgs
+    task: Union[CompletionRequest, ChatCompletionRequest]
 
 
 class TaskResponse(BaseModel):
@@ -349,11 +345,13 @@ class ModelParallelProcessGroup:
         self.process.join()
         self.started = False
 
-    def run_inference(self, inference_args: InferenceArgs) -> Generator:
+    def run_inference(
+        self, req: Union[CompletionRequest, ChatCompletionRequest]
+    ) -> Generator:
         assert not self.running, "inference already running"
 
         self.running = True
-        self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
+        self.request_socket.send(encode_msg(TaskRequest(task=req)))
         try:
             while True:
                 obj_json = self.request_socket.recv()
diff --git a/llama_stack/providers/impls/meta_reference/safety/llama_guard.py b/llama_stack/providers/impls/meta_reference/safety/llama_guard.py
index a6f450fae..99b1c29be 100644
--- a/llama_stack/providers/impls/meta_reference/safety/llama_guard.py
+++ b/llama_stack/providers/impls/meta_reference/safety/llama_guard.py
@@ -184,7 +184,7 @@ class LlamaGuardShield(ShieldBase):
         # TODO: llama-stack inference protocol has issues with non-streaming inference code
         content = ""
 
-        async for chunk in self.inference_api.chat_completion(
+        async for chunk in await self.inference_api.chat_completion(
             model=self.model,
             messages=[shield_input_message],
             stream=True,
diff --git a/llama_stack/providers/impls/vllm/vllm.py b/llama_stack/providers/impls/vllm/vllm.py
index 5cdb1a2ab..c977c738d 100644
--- a/llama_stack/providers/impls/vllm/vllm.py
+++ b/llama_stack/providers/impls/vllm/vllm.py
@@ -134,7 +134,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
         if self.engine:
             self.engine.shutdown_background_loop()
 
-    def completion(
+    async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -152,7 +152,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
             logprobs=logprobs,
         )
 
-    def chat_completion(
+    async def chat_completion(
         self,
         model: str,
         messages: list[Message],
@@ -189,7 +189,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
         if stream:
             return self._stream_chat_completion(request, results_generator)
         else:
-            return self._nonstream_chat_completion(request, results_generator)
+            return await self._nonstream_chat_completion(request, results_generator)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, results_generator: AsyncGenerator
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index 581a0d428..385e7efb9 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -146,7 +146,7 @@ async def test_chat_completion_streaming(inference_settings, sample_messages):
     inference_impl = inference_settings["impl"]
     response = [
         r
-        async for r in inference_impl.chat_completion(
+        async for r in await inference_impl.chat_completion(
             messages=sample_messages,
             stream=True,
             **inference_settings["common_params"],
@@ -217,7 +217,7 @@ async def test_chat_completion_with_tool_calling_streaming(
 
     response = [
         r
-        async for r in inference_impl.chat_completion(
+        async for r in await inference_impl.chat_completion(
             messages=messages,
             tools=[sample_tool_definition],
             stream=True,
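
With this change every method on the `Inference` protocol is `async def`, so a caller must
`await` the call itself; when `stream=True`, the awaited value is an async generator that is
driven with `async for`, as in the router, agent, and test call sites updated above. A minimal
caller-side sketch of the convention (the provider instance, model name, and messages below are
illustrative placeholders, not part of this patch):

    # non-streaming: awaiting the call returns the full ChatCompletionResponse
    response = await inference_impl.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=sample_messages,
        stream=False,
    )

    # streaming: awaiting the call returns an async generator of stream chunks
    async for chunk in await inference_impl.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=sample_messages,
        stream=True,
    ):
        print(chunk.event.delta)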