diff --git a/llama_stack/providers/inline/openai_responses/openai_responses.py b/llama_stack/providers/inline/openai_responses/openai_responses.py
index 2a137e5c1..5f5df6ad0 100644
--- a/llama_stack/providers/inline/openai_responses/openai_responses.py
+++ b/llama_stack/providers/inline/openai_responses/openai_responses.py
@@ -152,17 +152,45 @@ class OpenAIResponsesImpl(OpenAIResponses):
             messages.append(OpenAIUserMessageParam(content=user_content))
 
         chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
-        # TODO: the code below doesn't handle streaming
         chat_response = await self.inference_api.openai_chat_completion(
             model=model_obj.identifier,
             messages=messages,
             tools=chat_tools,
             stream=stream,
         )
-        # type cast to appease mypy
-        chat_response = cast(OpenAIChatCompletion, chat_response)
-        # dump and reload to map to our pydantic types
-        chat_response = OpenAIChatCompletion.model_validate_json(chat_response.model_dump_json())
+
+        if isinstance(chat_response, AsyncIterator):
+            # TODO: refactor this into a separate method that handles streaming
+            chat_response_id = ""
+            chat_response_content = []
+            # TODO: these chunk_ fields are hacky and only take the last chunk into account
+            chunk_created = 0
+            chunk_model = ""
+            chunk_finish_reason = ""
+            async for chunk in chat_response:
+                chat_response_id = chunk.id
+                chunk_created = chunk.created
+                chunk_model = chunk.model
+                for chunk_choice in chunk.choices:
+                    # TODO: this only works for text content
+                    chat_response_content.append(chunk_choice.delta.content or "")
+                    chunk_finish_reason = chunk_choice.finish_reason
+            assistant_message = OpenAIAssistantMessageParam(content="".join(chat_response_content))
+            chat_response = OpenAIChatCompletion(
+                id=chat_response_id,
+                choices=[
+                    OpenAIChoice(
+                        message=assistant_message,
+                        finish_reason=chunk_finish_reason,
+                        index=0,
+                    )
+                ],
+                created=chunk_created,
+                model=chunk_model,
+            )
+        else:
+            # dump and reload to map to our pydantic types
+            chat_response = OpenAIChatCompletion.model_validate_json(chat_response.model_dump_json())
 
         output_messages: List[OpenAIResponseOutput] = []
         if chat_response.choices[0].finish_reason == "tool_calls":
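
For reference, here is a minimal, self-contained sketch of the accumulation pattern the new `isinstance(chat_response, AsyncIterator)` branch uses: text deltas are concatenated into a single assistant message, while the id/created/model/finish_reason fields keep only the last chunk's values (a limitation the diff's own TODOs flag). The `Delta`/`Choice`/`Chunk` dataclasses and `fake_stream` below are hypothetical stand-ins for the real chat-completion chunk types, not part of this PR.

```python
import asyncio
from dataclasses import dataclass, field
from typing import AsyncIterator, List, Optional


# Hypothetical stand-ins for OpenAI-style chat-completion chunk objects.
@dataclass
class Delta:
    content: Optional[str] = None


@dataclass
class Choice:
    delta: Delta
    finish_reason: Optional[str] = None


@dataclass
class Chunk:
    id: str
    created: int
    model: str
    choices: List[Choice] = field(default_factory=list)


async def fake_stream() -> AsyncIterator[Chunk]:
    # Stand-in for the AsyncIterator returned by openai_chat_completion(stream=True).
    yield Chunk("chatcmpl-1", 1700000000, "llama-3", [Choice(Delta("Hello"))])
    yield Chunk("chatcmpl-1", 1700000000, "llama-3", [Choice(Delta(", world"), finish_reason="stop")])


async def accumulate(stream: AsyncIterator[Chunk]) -> dict:
    content: List[str] = []
    chunk_id, created, model, finish_reason = "", 0, "", None
    async for chunk in stream:
        # As in the diff, each chunk overwrites the metadata, so only the
        # last chunk's id/created/model/finish_reason survive.
        chunk_id, created, model = chunk.id, chunk.created, chunk.model
        for choice in chunk.choices:
            content.append(choice.delta.content or "")  # text deltas only
            finish_reason = choice.finish_reason
    return {
        "id": chunk_id,
        "created": created,
        "model": model,
        "finish_reason": finish_reason,
        "content": "".join(content),
    }


print(asyncio.run(accumulate(fake_stream())))
# -> {'id': 'chatcmpl-1', 'created': 1700000000, 'model': 'llama-3',
#     'finish_reason': 'stop', 'content': 'Hello, world'}
```

Like the branch in the diff, this sketch handles only text deltas; streamed tool-call deltas would need additional accumulation logic, which the TODO about refactoring into a separate streaming method leaves for later.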