mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-30 16:53:12 +00:00
This gets the fireworks provider passing 100% of our OpenAI API
verification tests when run against a Llama Stack server using the
fireworks provider. Testing against Fireworks directly, without Llama
Stack in the middle, has a lower pass rate.
The main changes are in how we divert Llama model OpenAI chat
completion requests to the Llama Stack chat completion API (vs
OpenAI), which applies all the client-side formatting necessary to get
tool calls working properly on Fireworks.
A side effect of this work is that any provider using the
OpenAIChatCompletionToLlamaStackMixin (renamed from
OpenAIChatCompletionUnsupportedMixin) will also get a better
conversion from OpenAI to Llama Stack, including both streaming and
non-streaming responses.
A small change was required to
`llama_stack/models/llama/llama3/tool_utils.py` to get tests to 100%
because code there was incorrectly assuming any JSON response with a
`name` key was a tool call response. One of our verification tests
produces JSON with a `name` key that is not a tool call response,
so I tightened up the logic there to require both a `name` and
`parameters` key in the JSON response before it gets considered a
potential tool call. The `parameters` key was required by the code
anyway, but it wasn't explicitly checking for its existence.
Lastly, this adds some new verification test configs so we can see the
results of using OpenAI APIs against SaaS services directly compared
to hitting Llama Stack with a remote provider pointing at that SaaS
service.
You can run these verification tests like:
```
llama stack run \
--image-type venv \
tests/verifications/openai-api-verification-run.yaml
python tests/verifications/generate_report.py \
--run-tests \
--provider together fireworks openai \
together-llama-stack \
fireworks-llama-stack \
openai-llama-stack
```
Signed-off-by: Ben Browning <bbrownin@redhat.com>
628 lines
24 KiB
Python
628 lines
24 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
|
|
import os
|
|
from typing import AsyncGenerator, List, Optional, Union
|
|
|
|
from pydantic import BaseModel
|
|
from termcolor import cprint
|
|
|
|
from llama_stack.apis.common.content_types import (
|
|
TextDelta,
|
|
ToolCallDelta,
|
|
ToolCallParseStatus,
|
|
)
|
|
from llama_stack.apis.inference import (
|
|
BatchChatCompletionResponse,
|
|
BatchCompletionResponse,
|
|
ChatCompletionRequest,
|
|
ChatCompletionResponse,
|
|
ChatCompletionResponseEvent,
|
|
ChatCompletionResponseEventType,
|
|
ChatCompletionResponseStreamChunk,
|
|
CompletionMessage,
|
|
CompletionRequest,
|
|
CompletionResponse,
|
|
CompletionResponseStreamChunk,
|
|
Inference,
|
|
InterleavedContent,
|
|
LogProbConfig,
|
|
Message,
|
|
ResponseFormat,
|
|
SamplingParams,
|
|
StopReason,
|
|
TokenLogProbs,
|
|
ToolChoice,
|
|
ToolConfig,
|
|
ToolDefinition,
|
|
ToolPromptFormat,
|
|
UserMessage,
|
|
)
|
|
from llama_stack.apis.models import Model, ModelType
|
|
from llama_stack.log import get_logger
|
|
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
|
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
|
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
|
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
|
from llama_stack.models.llama.sku_list import resolve_model
|
|
from llama_stack.models.llama.sku_types import ModelFamily
|
|
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
|
from llama_stack.providers.utils.inference.embedding_mixin import (
|
|
SentenceTransformerEmbeddingMixin,
|
|
)
|
|
from llama_stack.providers.utils.inference.model_registry import (
|
|
ModelRegistryHelper,
|
|
build_hf_repo_model_entry,
|
|
)
|
|
from llama_stack.providers.utils.inference.openai_compat import (
|
|
OpenAIChatCompletionToLlamaStackMixin,
|
|
OpenAICompletionToLlamaStackMixin,
|
|
)
|
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
|
augment_content_with_response_format_prompt,
|
|
chat_completion_request_to_messages,
|
|
convert_request_to_raw,
|
|
)
|
|
|
|
from .config import MetaReferenceInferenceConfig
|
|
from .generators import LlamaGenerator
|
|
from .model_parallel import LlamaModelParallelGenerator
|
|
|
|
log = get_logger(__name__, category="inference")

# The model-parallel setup runs a single process that serves the model, and that
# process cannot take multiple concurrent requests; this one-permit semaphore is
# used to serialize access to it.
SEMAPHORE = asyncio.Semaphore(value=1)
|
|
|
|
|
|
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
    """Construct a ``LlamaGenerator`` for ``llama_model``.

    Called directly for single-process serving, and also handed to
    ``LlamaModelParallelGenerator`` as its ``builder_fn`` for the model-parallel path.
    """
    generator = LlamaGenerator(config, model_id, llama_model)
    return generator
|
|
|
|
|
|
class MetaReferenceInferenceImpl(
    OpenAICompletionToLlamaStackMixin,
    OpenAIChatCompletionToLlamaStackMixin,
    SentenceTransformerEmbeddingMixin,
    Inference,
    ModelsProtocolPrivate,
):
    """In-process ("meta reference") Llama inference provider.

    A single Llama model is served per instance; it is resolved in ``register_model``
    and loaded in ``load_model``. When ``config.create_distributed_process_group`` is
    set, generation goes through a ``LlamaModelParallelGenerator`` and requests are
    serialized with the module-level ``SEMAPHORE``; otherwise the ``LlamaGenerator``
    returned by ``llama_builder_fn`` is used directly.
    """

    def __init__(self, config: MetaReferenceInferenceConfig) -> None:
        self.config = config
        # Populated by load_model(); None means no model has been loaded yet.
        self.model_id = None
        self.llama_model = None
        self.generator = None

    async def initialize(self) -> None:
        # Nothing to do up front; the model is loaded when it is registered.
        pass

    async def shutdown(self) -> None:
        """Stop the model-parallel generator if one was started.

        Only the model-parallel path calls ``start()`` in ``load_model``, so only that
        path needs ``stop()``. The ``None`` check makes shutdown safe even when no model
        was ever loaded.
        """
        if self.config.create_distributed_process_group and self.generator is not None:
            self.generator.stop()

    async def unregister_model(self, model_id: str) -> None:
        # No-op: the loaded model is not torn down on unregister.
        pass

    async def register_model(self, model: Model) -> Model:
        """Resolve ``model`` to a Llama SKU, register it, and load it unless skipped.

        The SKU is looked up from ``model.metadata["llama_model"]`` when present,
        otherwise from ``model.identifier``.

        Raises:
            ValueError: if the name does not resolve to a known Llama model.
        """
        llama_model = (
            resolve_model(model.metadata["llama_model"])
            if "llama_model" in model.metadata
            else resolve_model(model.identifier)
        )
        if llama_model is None:
            raise ValueError(
                "Please make sure your llama_model in model metadata or model identifier is in Llama SKU list"
            )

        self.model_registry_helper = ModelRegistryHelper(
            [
                build_hf_repo_model_entry(
                    llama_model.descriptor(),
                    llama_model.core_model_id.value,
                )
            ],
        )
        model = await self.model_registry_helper.register_model(model)

        if model.model_type == ModelType.embedding:
            self._load_sentence_transformer_model(model.provider_resource_id)

        # TODO: what is this?! you can't really specify skipping via model metadata
        # kill this madness
        if "skip_load" in model.metadata and model.metadata["skip_load"]:
            return model

        await self.load_model(model.identifier, llama_model)
        return model

    async def load_model(self, model_id, llama_model) -> None:
        """Build the generator for ``llama_model``, record it as the served model, and warm it up.

        Warm-up issues one short completion and one short chat completion before returning.
        """
        log.info(f"Loading model `{model_id}`")

        builder_params = [self.config, model_id, llama_model]

        if self.config.create_distributed_process_group:
            self.generator = LlamaModelParallelGenerator(
                model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
                builder_fn=llama_builder_fn,
                builder_params=builder_params,
                formatter=(
                    Llama4ChatFormat(Llama4Tokenizer.get_instance())
                    if llama_model.model_family == ModelFamily.llama4
                    else Llama3ChatFormat(Llama3Tokenizer.get_instance())
                ),
            )
            self.generator.start()
        else:
            self.generator = llama_builder_fn(*builder_params)

        self.model_id = model_id
        self.llama_model = llama_model

        log.info("Warming up...")
        await self.completion(
            model_id=model_id,
            content="Hello, world!",
            sampling_params=SamplingParams(max_tokens=10),
        )
        await self.chat_completion(
            model_id=model_id,
            messages=[UserMessage(content="Hi how are you?")],
            sampling_params=SamplingParams(max_tokens=20),
        )
        log.info("Warmed up!")

    def check_model(self, request) -> None:
        """Ensure a model is loaded and that ``request.model`` names it.

        Raises:
            RuntimeError: if no model is loaded, or the request targets a different model.
        """
        if self.model_id is None or self.llama_model is None:
            raise RuntimeError(
                "No available model yet, please register your requested model or add your model in the resources first"
            )
        elif request.model != self.model_id:
            raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
        """Run a raw-text completion against the loaded model.

        Returns an async generator of ``CompletionResponseStreamChunk`` when ``stream``
        is true, otherwise a single ``CompletionResponse``. Only ``logprobs.top_k == 1``
        is supported.
        """
        if sampling_params is None:
            sampling_params = SamplingParams()
        if logprobs:
            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

        content = augment_content_with_response_format_prompt(response_format, content)
        request = CompletionRequest(
            model=model_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        self.check_model(request)
        request = await convert_request_to_raw(request)

        if request.stream:
            return self._stream_completion(request)
        else:
            results = await self._nonstream_completion([request])
            return results[0]

    async def batch_completion(
        self,
        model_id: str,
        content_batch: List[InterleavedContent],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> BatchCompletionResponse:
        """Run non-streaming completions for every item of ``content_batch`` in one generator pass."""
        if sampling_params is None:
            sampling_params = SamplingParams()
        if logprobs:
            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

        content_batch = [
            augment_content_with_response_format_prompt(response_format, content) for content in content_batch
        ]

        request_batch = []
        for content in content_batch:
            request = CompletionRequest(
                model=model_id,
                content=content,
                sampling_params=sampling_params,
                response_format=response_format,
                stream=stream,
                logprobs=logprobs,
            )
            self.check_model(request)
            request = await convert_request_to_raw(request)
            request_batch.append(request)

        results = await self._nonstream_completion(request_batch)
        return BatchCompletionResponse(batch=results)

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        """Yield one ``CompletionResponseStreamChunk`` per generated token.

        A final ``out_of_tokens`` chunk is emitted when generation ends without an
        end-of-turn / end-of-message token.
        """
        tokenizer = self.generator.formatter.tokenizer

        def impl():
            stop_reason = None

            for token_result in self.generator.completion(request):
                if token_result.token == tokenizer.eot_id:
                    stop_reason = StopReason.end_of_turn
                    text = ""
                elif token_result.token == tokenizer.eom_id:
                    stop_reason = StopReason.end_of_message
                    text = ""
                else:
                    text = token_result.text

                logprobs = None
                if stop_reason is None:
                    if request.logprobs:
                        assert len(token_result.logprobs) == 1

                        logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]

                yield CompletionResponseStreamChunk(
                    delta=text,
                    stop_reason=stop_reason,
                    logprobs=logprobs if request.logprobs else None,
                )

            if stop_reason is None:
                yield CompletionResponseStreamChunk(
                    delta="",
                    stop_reason=StopReason.out_of_tokens,
                )

        if self.config.create_distributed_process_group:
            async with SEMAPHORE:
                for x in impl():
                    yield x
        else:
            for x in impl():
                yield x

    async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
        """Generate all requests in ``request_batch`` together and decode one response per request.

        Per-request state is tracked by ``batch_idx``. Logprob collection is governed by
        the first request's ``logprobs`` setting.
        """
        tokenizer = self.generator.formatter.tokenizer

        first_request = request_batch[0]

        class ItemState(BaseModel):
            tokens: List[int] = []
            logprobs: List[TokenLogProbs] = []
            stop_reason: StopReason | None = None
            finished: bool = False

        def impl():
            states = [ItemState() for _ in request_batch]

            results = []
            for token_results in self.generator.completion(request_batch):
                for result in token_results:
                    idx = result.batch_idx
                    state = states[idx]
                    if state.finished or result.ignore_token:
                        continue

                    state.finished = result.finished
                    if first_request.logprobs:
                        state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))

                    state.tokens.append(result.token)
                    if result.token == tokenizer.eot_id:
                        state.stop_reason = StopReason.end_of_turn
                    elif result.token == tokenizer.eom_id:
                        state.stop_reason = StopReason.end_of_message

            for state in states:
                if state.stop_reason is None:
                    state.stop_reason = StopReason.out_of_tokens

                # Strip a trailing stop token so it is not decoded into the text. The
                # emptiness check avoids an IndexError when no token was kept for this item.
                if state.tokens and state.tokens[-1] in tokenizer.stop_tokens:
                    state.tokens = state.tokens[:-1]
                content = tokenizer.decode(state.tokens)
                results.append(
                    CompletionResponse(
                        content=content,
                        stop_reason=state.stop_reason,
                        logprobs=state.logprobs if first_request.logprobs else None,
                    )
                )

            return results

        if self.config.create_distributed_process_group:
            async with SEMAPHORE:
                return impl()
        else:
            return impl()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> AsyncGenerator:
        """Run a chat completion against the loaded model.

        ``tool_choice`` / ``tool_prompt_format`` are used only when ``tool_config`` is not
        given; an explicit ``tool_config`` takes precedence. Returns an async generator of
        stream chunks when ``stream`` is true, otherwise a single ``ChatCompletionResponse``.

        Raises:
            RuntimeError: in model-parallel mode, if another request currently holds the semaphore.
        """
        if sampling_params is None:
            sampling_params = SamplingParams()
        if logprobs:
            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

        if tool_config is None:
            # Previously these arguments were silently dropped when tool_config was absent.
            tool_config_params = {}
            if tool_choice is not None:
                tool_config_params["tool_choice"] = tool_choice
            if tool_prompt_format is not None:
                tool_config_params["tool_prompt_format"] = tool_prompt_format
            tool_config = ToolConfig(**tool_config_params)

        # wrapper request to make it easier to pass around (internal only, not exposed to API)
        request = ChatCompletionRequest(
            model=model_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
            tool_config=tool_config,
        )
        self.check_model(request)

        # augment and rewrite messages depending on the model
        request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
        # download media and convert to raw content so we can send it to the model
        request = await convert_request_to_raw(request)

        if self.config.create_distributed_process_group:
            if SEMAPHORE.locked():
                raise RuntimeError("Only one concurrent request is supported")

        if request.stream:
            return self._stream_chat_completion(request)
        else:
            results = await self._nonstream_chat_completion([request])
            return results[0]

    async def batch_chat_completion(
        self,
        model_id: str,
        messages_batch: List[List[Message]],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> BatchChatCompletionResponse:
        """Run non-streaming chat completions for every conversation in ``messages_batch``.

        Raises:
            RuntimeError: in model-parallel mode, if another request currently holds the semaphore.
        """
        if sampling_params is None:
            sampling_params = SamplingParams()
        if logprobs:
            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

        # wrapper request to make it easier to pass around (internal only, not exposed to API)
        request_batch = []
        for messages in messages_batch:
            request = ChatCompletionRequest(
                model=model_id,
                messages=messages,
                sampling_params=sampling_params,
                tools=tools or [],
                response_format=response_format,
                logprobs=logprobs,
                tool_config=tool_config or ToolConfig(),
            )
            self.check_model(request)

            # augment and rewrite messages depending on the model
            request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
            # download media and convert to raw content so we can send it to the model
            request = await convert_request_to_raw(request)
            request_batch.append(request)

        if self.config.create_distributed_process_group:
            if SEMAPHORE.locked():
                raise RuntimeError("Only one concurrent request is supported")

        results = await self._nonstream_chat_completion(request_batch)
        return BatchChatCompletionResponse(batch=results)

    async def _nonstream_chat_completion(
        self, request_batch: List[ChatCompletionRequest]
    ) -> List[ChatCompletionResponse]:
        """Generate all chat requests together and decode each into an assistant message.

        Tool calls are parsed by ``formatter.decode_assistant_message``. Logprob collection
        is governed by the first request's ``logprobs`` setting.
        """
        tokenizer = self.generator.formatter.tokenizer

        first_request = request_batch[0]

        class ItemState(BaseModel):
            tokens: List[int] = []
            logprobs: List[TokenLogProbs] = []
            stop_reason: StopReason | None = None
            finished: bool = False

        def impl():
            states = [ItemState() for _ in request_batch]
            debug_level = os.environ.get("LLAMA_MODELS_DEBUG", "0")

            for token_results in self.generator.chat_completion(request_batch):
                # Debug echo of the first batch item's tokens only.
                first = token_results[0]
                if not first.finished and not first.ignore_token:
                    if debug_level in ("1", "2"):
                        cprint(first.text, "cyan", end="")
                    if debug_level == "2":
                        cprint(f"<{first.token}>", "magenta", end="")

                for result in token_results:
                    idx = result.batch_idx
                    state = states[idx]
                    if state.finished or result.ignore_token:
                        continue

                    state.finished = result.finished
                    if first_request.logprobs:
                        state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))

                    state.tokens.append(result.token)
                    if result.token == tokenizer.eot_id:
                        state.stop_reason = StopReason.end_of_turn
                    elif result.token == tokenizer.eom_id:
                        state.stop_reason = StopReason.end_of_message

            results = []
            for state in states:
                if state.stop_reason is None:
                    state.stop_reason = StopReason.out_of_tokens

                raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
                results.append(
                    ChatCompletionResponse(
                        completion_message=CompletionMessage(
                            content=raw_message.content,
                            stop_reason=raw_message.stop_reason,
                            tool_calls=raw_message.tool_calls,
                        ),
                        logprobs=state.logprobs if first_request.logprobs else None,
                    )
                )

            return results

        if self.config.create_distributed_process_group:
            async with SEMAPHORE:
                return impl()
        else:
            return impl()

    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        """Stream a chat completion as start / progress / complete events.

        Once a ``<|python_tag|>`` token is seen, later text is streamed as in-progress
        ``ToolCallDelta``s. After generation, the collected tokens are decoded and either
        the parsed tool calls or a parse-failure event are emitted before the final
        ``complete`` event.
        """
        tokenizer = self.generator.formatter.tokenizer

        def impl():
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.start,
                    delta=TextDelta(text=""),
                )
            )

            tokens = []
            logprobs = []
            stop_reason = None
            ipython = False
            debug_level = os.environ.get("LLAMA_MODELS_DEBUG", "0")

            for token_result in self.generator.chat_completion(request):
                if debug_level == "1":
                    cprint(token_result.text, "cyan", end="")
                if debug_level == "2":
                    cprint(f"<{token_result.token}>", "magenta", end="")

                tokens.append(token_result.token)

                if not ipython and token_result.text.startswith("<|python_tag|>"):
                    ipython = True
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=ToolCallDelta(
                                tool_call="",
                                parse_status=ToolCallParseStatus.started,
                            ),
                        )
                    )
                    continue

                # Stop-reason detection and the logprob append now happen once per token.
                # Previously both were duplicated, so each token's logprob was recorded twice.
                if token_result.token == tokenizer.eot_id:
                    stop_reason = StopReason.end_of_turn
                    text = ""
                elif token_result.token == tokenizer.eom_id:
                    stop_reason = StopReason.end_of_message
                    text = ""
                else:
                    text = token_result.text

                if ipython:
                    delta = ToolCallDelta(
                        tool_call=text,
                        parse_status=ToolCallParseStatus.in_progress,
                    )
                else:
                    delta = TextDelta(text=text)

                if stop_reason is None:
                    if request.logprobs:
                        assert len(token_result.logprobs) == 1

                        logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=delta,
                            stop_reason=stop_reason,
                            logprobs=logprobs if request.logprobs else None,
                        )
                    )

            if stop_reason is None:
                stop_reason = StopReason.out_of_tokens

            message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)

            parsed_tool_calls = len(message.tool_calls) > 0
            if ipython and not parsed_tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            tool_call="",
                            parse_status=ToolCallParseStatus.failed,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            for tool_call in message.tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            tool_call=tool_call,
                            parse_status=ToolCallParseStatus.succeeded,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.complete,
                    delta=TextDelta(text=""),
                    stop_reason=stop_reason,
                )
            )

        if self.config.create_distributed_process_group:
            async with SEMAPHORE:
                for x in impl():
                    yield x
        else:
            for x in impl():
                yield x