forked from phoenix-oss/llama-stack-mirror
# What does this PR do?

TLDR: Changes needed to get the OpenAI API verification tests to 100% passing when run against Llama Stack with the `together`, `fireworks`, and `openai` providers. `groq` is also better than before, at 88% passing.

This cleans up the OpenAI API support for image message types (specifically `image_url` types) and the handling of the `response_format` chat completion parameter. Both of these required a few more Pydantic model definitions in our Inference API, to move from the not-quite-right stubs I had in place to something fleshed out to match the actual OpenAI API specs.

As part of testing this, I also found and fixed a bug in the litellm implementation of `openai_completion` and `openai_chat_completion`, so the providers based on those should actually work now.

The method `prepare_openai_completion_params` in `llama_stack/providers/utils/inference/openai_compat.py` was improved to recursively clean up input parameters, including handling of lists, dicts, and dumping of Pydantic models to dicts. These changes were required to get to 100% passing tests on the OpenAI API verification against the `openai` provider.

With the above, the together.ai provider was passing as well as it does without Llama Stack. But, since we have Llama Stack in the middle, I took the opportunity to clean up the together.ai provider so that it now also passes the OpenAI API spec tests we have at 100%. That means together.ai now passes our verification tests better when an OpenAI client talks to Llama Stack than when it hits together.ai directly, without Llama Stack in the middle.

Another round of work for Fireworks, improving the translation of incoming OpenAI chat completion requests to Llama Stack chat completion requests, gets the fireworks provider passing at 100% as well.
The server-side fireworks.ai tool calling support with OpenAI chat completions and Llama 4 models isn't great yet, but by pointing the OpenAI clients at Llama Stack's API we can clean things up and get everything working as expected for Llama 4 models.

## Test Plan

### OpenAI API Verification Tests

I ran the OpenAI API verification tests as below and 100% of the tests passed.

First, start a Llama Stack server that runs the `openai` provider with the `gpt-4o` and `gpt-4o-mini` models deployed. There's no template set up to do this out of the box, so I added a `tests/verifications/openai-api-verification-run.yaml` to do this.

Before starting it, ensure you have the necessary API key environment variables set:

```
export TOGETHER_API_KEY="..."
export FIREWORKS_API_KEY="..."
export OPENAI_API_KEY="..."
```

Then, run a Llama Stack server that serves up all these providers:

```
llama stack run \
  --image-type venv \
  tests/verifications/openai-api-verification-run.yaml
```

Finally, generate a new verification report against all these providers, both with and without the Llama Stack server in the middle.

```
python tests/verifications/generate_report.py \
  --run-tests \
  --provider \
    together \
    fireworks \
    groq \
    openai \
    together-llama-stack \
    fireworks-llama-stack \
    groq-llama-stack \
    openai-llama-stack
```

You'll see that most of the configurations with Llama Stack in the middle now pass at 100%, even though some of them do not pass at 100% when hitting the backend provider's API directly with an OpenAI client.

### OpenAI Completion Integration Tests with vLLM

I also ran the smaller `test_openai_completion.py` test suite (which is not yet merged with the verification tests) against several of the providers, since I had to adjust the method signature of `openai_chat_completion` a bit and thus had to touch many of these providers to match.
Here are the tests I ran there, all passing:

```
VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" llama stack build --template remote-vllm --image-type venv --run
```

In another terminal:

```
LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "meta-llama/Llama-3.2-3B-Instruct"
```

### OpenAI Completion Integration Tests with ollama

```
INFERENCE_MODEL="llama3.2:3b-instruct-q8_0" llama stack build --template ollama --image-type venv --run
```

In another terminal:

```
LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="llama3.2:3b-instruct-q8_0" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "llama3.2:3b-instruct-q8_0"
```

### OpenAI Completion Integration Tests with together.ai

```
INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct-Turbo" llama stack build --template together --image-type venv --run
```

In another terminal:

```
LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct-Turbo" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "meta-llama/Llama-3.2-3B-Instruct-Turbo"
```

### OpenAI Completion Integration Tests with fireworks.ai

```
INFERENCE_MODEL="meta-llama/Llama-3.1-8B-Instruct" llama stack build --template fireworks --image-type venv --run
```

In another terminal:

```
LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="meta-llama/Llama-3.1-8B-Instruct" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "meta-llama/Llama-3.1-8B-Instruct"
```

---------

Signed-off-by: Ben Browning <bbrownin@redhat.com>
628 lines
24 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import os
from typing import AsyncGenerator, List, Optional, Union

from pydantic import BaseModel
from termcolor import cprint

from llama_stack.apis.common.content_types import (
    TextDelta,
    ToolCallDelta,
    ToolCallParseStatus,
)
from llama_stack.apis.inference import (
    BatchChatCompletionResponse,
    BatchCompletionResponse,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionMessage,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseStreamChunk,
    Inference,
    InterleavedContent,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    StopReason,
    TokenLogProbs,
    ToolChoice,
    ToolConfig,
    ToolDefinition,
    ToolPromptFormat,
    UserMessage,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
    SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAIChatCompletionToLlamaStackMixin,
    OpenAICompletionToLlamaStackMixin,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    augment_content_with_response_format_prompt,
    chat_completion_request_to_messages,
    convert_request_to_raw,
)

from .config import MetaReferenceInferenceConfig
from .generators import LlamaGenerator
from .model_parallel import LlamaModelParallelGenerator

log = get_logger(__name__, category="inference")
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)


def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
    return LlamaGenerator(config, model_id, llama_model)


class MetaReferenceInferenceImpl(
    OpenAICompletionToLlamaStackMixin,
    OpenAIChatCompletionToLlamaStackMixin,
    SentenceTransformerEmbeddingMixin,
    Inference,
    ModelsProtocolPrivate,
):
    def __init__(self, config: MetaReferenceInferenceConfig) -> None:
        self.config = config
        self.model_id = None
        self.llama_model = None

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        if self.config.create_distributed_process_group:
            self.generator.stop()

    async def unregister_model(self, model_id: str) -> None:
        pass

    async def register_model(self, model: Model) -> Model:
        llama_model = (
            resolve_model(model.metadata["llama_model"])
            if "llama_model" in model.metadata
            else resolve_model(model.identifier)
        )
        if llama_model is None:
            raise ValueError(
                "Please make sure your llama_model in model metadata or model identifier is in Llama SKU list"
            )

        self.model_registry_helper = ModelRegistryHelper(
            [
                build_hf_repo_model_entry(
                    llama_model.descriptor(),
                    llama_model.core_model_id.value,
                )
            ],
        )
        model = await self.model_registry_helper.register_model(model)

        if model.model_type == ModelType.embedding:
            self._load_sentence_transformer_model(model.provider_resource_id)

        # TODO: what is this?! you can't really specify skipping via model metadata
        # kill this madness
        if "skip_load" in model.metadata and model.metadata["skip_load"]:
            return model

        await self.load_model(model.identifier, llama_model)
        return model

    async def load_model(self, model_id, llama_model) -> None:
        log.info(f"Loading model `{model_id}`")

        builder_params = [self.config, model_id, llama_model]

        if self.config.create_distributed_process_group:
            self.generator = LlamaModelParallelGenerator(
                model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
                builder_fn=llama_builder_fn,
                builder_params=builder_params,
                formatter=(
                    Llama4ChatFormat(Llama4Tokenizer.get_instance())
                    if llama_model.model_family == ModelFamily.llama4
                    else Llama3ChatFormat(Llama3Tokenizer.get_instance())
                ),
            )
            self.generator.start()
        else:
            self.generator = llama_builder_fn(*builder_params)

        self.model_id = model_id
        self.llama_model = llama_model

        log.info("Warming up...")
        await self.completion(
            model_id=model_id,
            content="Hello, world!",
            sampling_params=SamplingParams(max_tokens=10),
        )
        await self.chat_completion(
            model_id=model_id,
            messages=[UserMessage(content="Hi how are you?")],
            sampling_params=SamplingParams(max_tokens=20),
        )
        log.info("Warmed up!")

    def check_model(self, request) -> None:
        if self.model_id is None or self.llama_model is None:
            raise RuntimeError(
                "No available model yet, please register your requested model or add your model in the resources first"
            )
        elif request.model != self.model_id:
            raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
        if sampling_params is None:
            sampling_params = SamplingParams()
        if logprobs:
            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

        content = augment_content_with_response_format_prompt(response_format, content)
        request = CompletionRequest(
            model=model_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        self.check_model(request)
        request = await convert_request_to_raw(request)

        if request.stream:
            return self._stream_completion(request)
        else:
            results = await self._nonstream_completion([request])
            return results[0]

    async def batch_completion(
        self,
        model_id: str,
        content_batch: List[InterleavedContent],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> BatchCompletionResponse:
        if sampling_params is None:
            sampling_params = SamplingParams()
        if logprobs:
            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

        content_batch = [
            augment_content_with_response_format_prompt(response_format, content) for content in content_batch
        ]

        request_batch = []
        for content in content_batch:
            request = CompletionRequest(
                model=model_id,
                content=content,
                sampling_params=sampling_params,
                response_format=response_format,
                stream=stream,
                logprobs=logprobs,
            )
            self.check_model(request)
            request = await convert_request_to_raw(request)
            request_batch.append(request)

        results = await self._nonstream_completion(request_batch)
        return BatchCompletionResponse(batch=results)

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        tokenizer = self.generator.formatter.tokenizer

        def impl():
            stop_reason = None

            for token_result in self.generator.completion(request):
                if token_result.token == tokenizer.eot_id:
                    stop_reason = StopReason.end_of_turn
                    text = ""
                elif token_result.token == tokenizer.eom_id:
                    stop_reason = StopReason.end_of_message
                    text = ""
                else:
                    text = token_result.text

                logprobs = None
                if stop_reason is None:
                    if request.logprobs:
                        assert len(token_result.logprobs) == 1

                        logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]

                yield CompletionResponseStreamChunk(
                    delta=text,
                    stop_reason=stop_reason,
                    logprobs=logprobs if request.logprobs else None,
                )

            if stop_reason is None:
                yield CompletionResponseStreamChunk(
                    delta="",
                    stop_reason=StopReason.out_of_tokens,
                )

        if self.config.create_distributed_process_group:
            async with SEMAPHORE:
                for x in impl():
                    yield x
        else:
            for x in impl():
                yield x

    async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
        tokenizer = self.generator.formatter.tokenizer

        first_request = request_batch[0]

        class ItemState(BaseModel):
            tokens: List[int] = []
            logprobs: List[TokenLogProbs] = []
            stop_reason: StopReason | None = None
            finished: bool = False

        def impl():
            states = [ItemState() for _ in request_batch]

            results = []
            for token_results in self.generator.completion(request_batch):
                for result in token_results:
                    idx = result.batch_idx
                    state = states[idx]
                    if state.finished or result.ignore_token:
                        continue

                    state.finished = result.finished
                    if first_request.logprobs:
                        state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))

                    state.tokens.append(result.token)
                    if result.token == tokenizer.eot_id:
                        state.stop_reason = StopReason.end_of_turn
                    elif result.token == tokenizer.eom_id:
                        state.stop_reason = StopReason.end_of_message

            for state in states:
                if state.stop_reason is None:
                    state.stop_reason = StopReason.out_of_tokens

                if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
                    state.tokens = state.tokens[:-1]
                content = self.generator.formatter.tokenizer.decode(state.tokens)
                results.append(
                    CompletionResponse(
                        content=content,
                        stop_reason=state.stop_reason,
                        logprobs=state.logprobs if first_request.logprobs else None,
                    )
                )

            return results

        if self.config.create_distributed_process_group:
            async with SEMAPHORE:
                return impl()
        else:
            return impl()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> AsyncGenerator:
        if sampling_params is None:
            sampling_params = SamplingParams()
        if logprobs:
            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

        # wrapper request to make it easier to pass around (internal only, not exposed to API)
        request = ChatCompletionRequest(
            model=model_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
            tool_config=tool_config or ToolConfig(),
        )
        self.check_model(request)

        # augment and rewrite messages depending on the model
        request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
        # download media and convert to raw content so we can send it to the model
        request = await convert_request_to_raw(request)

        if self.config.create_distributed_process_group:
            if SEMAPHORE.locked():
                raise RuntimeError("Only one concurrent request is supported")

        if request.stream:
            return self._stream_chat_completion(request)
        else:
            results = await self._nonstream_chat_completion([request])
            return results[0]

    async def batch_chat_completion(
        self,
        model_id: str,
        messages_batch: List[List[Message]],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> BatchChatCompletionResponse:
        if sampling_params is None:
            sampling_params = SamplingParams()
        if logprobs:
            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

        # wrapper request to make it easier to pass around (internal only, not exposed to API)
        request_batch = []
        for messages in messages_batch:
            request = ChatCompletionRequest(
                model=model_id,
                messages=messages,
                sampling_params=sampling_params,
                tools=tools or [],
                response_format=response_format,
                logprobs=logprobs,
                tool_config=tool_config or ToolConfig(),
            )
            self.check_model(request)

            # augment and rewrite messages depending on the model
            request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
            # download media and convert to raw content so we can send it to the model
            request = await convert_request_to_raw(request)
            request_batch.append(request)

        if self.config.create_distributed_process_group:
            if SEMAPHORE.locked():
                raise RuntimeError("Only one concurrent request is supported")

        results = await self._nonstream_chat_completion(request_batch)
        return BatchChatCompletionResponse(batch=results)

    async def _nonstream_chat_completion(
        self, request_batch: List[ChatCompletionRequest]
    ) -> List[ChatCompletionResponse]:
        tokenizer = self.generator.formatter.tokenizer

        first_request = request_batch[0]

        class ItemState(BaseModel):
            tokens: List[int] = []
            logprobs: List[TokenLogProbs] = []
            stop_reason: StopReason | None = None
            finished: bool = False

        def impl():
            states = [ItemState() for _ in request_batch]

            for token_results in self.generator.chat_completion(request_batch):
                first = token_results[0]
                if not first.finished and not first.ignore_token:
                    if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
                        cprint(first.text, "cyan", end="")
                    if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
                        cprint(f"<{first.token}>", "magenta", end="")

                for result in token_results:
                    idx = result.batch_idx
                    state = states[idx]
                    if state.finished or result.ignore_token:
                        continue

                    state.finished = result.finished
                    if first_request.logprobs:
                        state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))

                    state.tokens.append(result.token)
                    if result.token == tokenizer.eot_id:
                        state.stop_reason = StopReason.end_of_turn
                    elif result.token == tokenizer.eom_id:
                        state.stop_reason = StopReason.end_of_message

            results = []
            for state in states:
                if state.stop_reason is None:
                    state.stop_reason = StopReason.out_of_tokens

                raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
                results.append(
                    ChatCompletionResponse(
                        completion_message=CompletionMessage(
                            content=raw_message.content,
                            stop_reason=raw_message.stop_reason,
                            tool_calls=raw_message.tool_calls,
                        ),
                        logprobs=state.logprobs if first_request.logprobs else None,
                    )
                )

            return results

        if self.config.create_distributed_process_group:
            async with SEMAPHORE:
                return impl()
        else:
            return impl()

    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        tokenizer = self.generator.formatter.tokenizer

        def impl():
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.start,
                    delta=TextDelta(text=""),
                )
            )

            tokens = []
            logprobs = []
            stop_reason = None
            ipython = False

            for token_result in self.generator.chat_completion(request):
                if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
                    cprint(token_result.text, "cyan", end="")
                if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
                    cprint(f"<{token_result.token}>", "magenta", end="")

                if token_result.token == tokenizer.eot_id:
                    stop_reason = StopReason.end_of_turn
                    text = ""
                elif token_result.token == tokenizer.eom_id:
                    stop_reason = StopReason.end_of_message
                    text = ""
                else:
                    text = token_result.text

                # Record logprobs once per generated token, as the non-streaming path does.
                if request.logprobs:
                    assert len(token_result.logprobs) == 1

                    logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))

                tokens.append(token_result.token)

                if not ipython and token_result.text.startswith("<|python_tag|>"):
                    ipython = True
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=ToolCallDelta(
                                tool_call="",
                                parse_status=ToolCallParseStatus.started,
                            ),
                        )
                    )
                    continue

                if ipython:
                    delta = ToolCallDelta(
                        tool_call=text,
                        parse_status=ToolCallParseStatus.in_progress,
                    )
                else:
                    delta = TextDelta(text=text)

                if stop_reason is None:
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=delta,
                            stop_reason=stop_reason,
                            logprobs=logprobs if request.logprobs else None,
                        )
                    )

            if stop_reason is None:
                stop_reason = StopReason.out_of_tokens

            message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)

            parsed_tool_calls = len(message.tool_calls) > 0
            if ipython and not parsed_tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            tool_call="",
                            parse_status=ToolCallParseStatus.failed,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            for tool_call in message.tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            tool_call=tool_call,
                            parse_status=ToolCallParseStatus.succeeded,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.complete,
                    delta=TextDelta(text=""),
                    stop_reason=stop_reason,
                )
            )

        if self.config.create_distributed_process_group:
            async with SEMAPHORE:
                for x in impl():
                    yield x
        else:
            for x in impl():
                yield x