test fixes

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka 2025-03-25 13:04:24 -04:00
parent 969952ac8d
commit 6aedfc2201
2 changed files with 5 additions and 5 deletions


@@ -245,7 +245,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk]:
+    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
         assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
@@ -275,7 +275,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk]:
+    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
@@ -319,7 +319,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest, client: AsyncOpenAI
-    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk]:
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         params = await self._get_params(request)
         stream = await client.chat.completions.create(**params)
@@ -336,7 +336,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         r = await self.client.completions.create(**params)
         return process_completion_response(r)
-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator[CompletionResponseStreamChunk]:
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
         assert self.client is not None
         params = await self._get_params(request)


@@ -337,7 +337,7 @@ async def process_completion_stream_response(
 async def process_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
     request: ChatCompletionRequest,
-) -> AsyncGenerator[ChatCompletionResponseStreamChunk]:
+) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
             event_type=ChatCompletionResponseEventType.start,
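
Context for the change (a minimal sketch, not part of the patch): typing.AsyncGenerator is parameterized as AsyncGenerator[YieldType, SendType], and on Python versions before 3.13 the send type has no default, so a one-argument form such as AsyncGenerator[CompletionResponseStreamChunk] can raise a TypeError when the annotation is evaluated at runtime (for example via typing.get_type_hints), which is presumably what the tests tripped over. The standalone example below, with hypothetical names, shows the two-parameter form; None is the send type because nothing is sent back into the generator.

# Minimal standalone sketch (hypothetical names, not from this repository)
# showing the AsyncGenerator[YieldType, SendType] form used in the fix.
import asyncio
from typing import AsyncGenerator


async def stream_chunks(limit: int) -> AsyncGenerator[str, None]:
    # Yields strings; the send type is None because callers only iterate,
    # they never pass values back in with .asend().
    for i in range(limit):
        yield f"chunk-{i}"


async def main() -> None:
    async for chunk in stream_chunks(3):
        print(chunk)


if __name__ == "__main__":
    asyncio.run(main())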