introduce openai_compat with the completions (not chat-completions) API

This keeps the prompt encoding layer under our control (see the
`chat_completion_request_to_prompt()` helper).
Ashwin Bharambe authored and committed on 2024-10-08 12:15:55 -07:00
commit 05e73d12b3
parent 0c9eb3341c
6 changed files with 354 additions and 513 deletions
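
For orientation, a minimal sketch of the adapter pattern this commit introduces: encode the prompt ourselves, translate sampling params in one place, call the provider's plain completions endpoint, and decode the raw text through the new OpenAI-compat shim. The provider client and its `completions()` call are hypothetical stand-ins; only the llama_stack helpers shown in the diff below are real, and the provider reply is assumed to be OpenAI-completions shaped (`choices[0].text` / `choices[0].finish_reason`).

from llama_models.llama3.api.chat_format import ChatFormat

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
    chat_completion_request_to_prompt,
)
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
    process_chat_completion_response,
)


async def nonstream_chat_completion(
    request: ChatCompletionRequest, client, formatter: ChatFormat
) -> ChatCompletionResponse:
    # Prompt encoding (tool/system-message augmentation included) stays in
    # llama_stack instead of the provider's chat template.
    prompt = chat_completion_request_to_prompt(request, formatter)
    # Sampling params are translated once, in one place.
    options = get_sampling_options(request)
    # Hypothetical provider SDK call: a bare completions API, not chat-completions.
    raw = await client.completions(prompt=prompt, **options)
    # The shim decodes the raw text back into a CompletionMessage (tool calls included).
    return process_chat_completion_response(request, raw, formatter)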


@@ -10,14 +10,19 @@ from fireworks.client import Fireworks
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message, StopReason
+from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
-    augment_messages_for_tools,
+    chat_completion_request_to_prompt,
 )
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_options,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+)

 from .config import FireworksImplConfig

@@ -38,12 +43,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
             self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
         )
         self.config = config
-        self.tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(self.tokenizer)
-
-    @property
-    def client(self) -> Fireworks:
-        return Fireworks(api_key=self.config.api_key)
+        self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self) -> None:
         return

@@ -51,7 +51,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass

-    async def completion(
+    def completion(
         self,
         model: str,
         content: InterleavedTextMedia,

@@ -61,16 +61,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def get_fireworks_chat_options(self, request: ChatCompletionRequest) -> dict:
-        options = {}
-        if request.sampling_params is not None:
-            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
-                if getattr(request.sampling_params, attr):
-                    options[attr] = getattr(request.sampling_params, attr)
-        return options
-
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],

@@ -92,154 +83,41 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
             logprobs=logprobs,
         )

-        messages = augment_messages_for_tools(request)
-        model_input = self.formatter.encode_dialog_prompt(messages)
-        prompt = self.tokenizer.decode(model_input.tokens)
+        client = Fireworks(api_key=self.config.api_key)
+        if stream:
+            return self._stream_chat_completion(request, client)
+        else:
+            return self._nonstream_chat_completion(request, client)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest, client: Fireworks
+    ) -> ChatCompletionResponse:
+        params = self._get_params(request)
+        r = await client.completion.acreate(**params)
+        return process_chat_completion_response(request, r, self.formatter)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: Fireworks
+    ) -> AsyncGenerator:
+        params = self._get_params(request)
+
+        stream = client.completion.acreate(**params)
+        async for chunk in process_chat_completion_stream_response(
+            request, stream, self.formatter
+        ):
+            yield chunk
+
+    def _get_params(self, request: ChatCompletionRequest) -> dict:
+        prompt = chat_completion_request_to_prompt(request, self.formatter)
         # Fireworks always prepends with BOS
         if prompt.startswith("<|begin_of_text|>"):
             prompt = prompt[len("<|begin_of_text|>") :]

-        # accumulate sampling params and other options to pass to fireworks
-        options = self.get_fireworks_chat_options(request)
+        options = get_sampling_options(request)
         options.setdefault("max_tokens", 512)
-
-        fireworks_model = self.map_to_provider_model(request.model)
-
-        if not request.stream:
-            r = await self.client.completion.acreate(
-                model=fireworks_model,
-                prompt=prompt,
-                stream=False,
-                **options,
-            )
-            stop_reason = None
-            if r.choices[0].finish_reason:
-                if r.choices[0].finish_reason == "stop":
-                    stop_reason = StopReason.end_of_turn
-                elif r.choices[0].finish_reason == "length":
-                    stop_reason = StopReason.out_of_tokens
-
-            completion_message = self.formatter.decode_assistant_message_from_content(
-                r.choices[0].text, stop_reason
-            )
-            yield ChatCompletionResponse(
-                completion_message=completion_message,
-                logprobs=None,
-            )
-        else:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.start,
-                    delta="",
-                )
-            )
-
-            buffer = ""
-            ipython = False
-            stop_reason = None
-
-            async for chunk in self.client.completion.acreate(
-                model=fireworks_model,
-                prompt=prompt,
-                stream=True,
-                **options,
-            ):
-                if chunk.choices[0].finish_reason:
-                    if stop_reason is None and chunk.choices[0].finish_reason == "stop":
-                        stop_reason = StopReason.end_of_turn
-                    elif (
-                        stop_reason is None
-                        and chunk.choices[0].finish_reason == "length"
-                    ):
-                        stop_reason = StopReason.out_of_tokens
-                    break
-
-                text = chunk.choices[0].text
-                if text is None:
-                    continue
-
-                # check if its a tool call ( aka starts with <|python_tag|> )
-                if not ipython and text.startswith("<|python_tag|>"):
-                    ipython = True
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content="",
-                                parse_status=ToolCallParseStatus.started,
-                            ),
-                        )
-                    )
-                    buffer += text
-                    continue
-
-                if ipython:
-                    if text == "<|eot_id|>":
-                        stop_reason = StopReason.end_of_turn
-                        text = ""
-                        continue
-                    elif text == "<|eom_id|>":
-                        stop_reason = StopReason.end_of_message
-                        text = ""
-                        continue
-
-                    buffer += text
-                    delta = ToolCallDelta(
-                        content=text,
-                        parse_status=ToolCallParseStatus.in_progress,
-                    )
-
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=delta,
-                            stop_reason=stop_reason,
-                        )
-                    )
-                else:
-                    buffer += text
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=text,
-                            stop_reason=stop_reason,
-                        )
-                    )
-
-            # parse tool calls and report errors
-            message = self.formatter.decode_assistant_message_from_content(
-                buffer, stop_reason
-            )
-            parsed_tool_calls = len(message.tool_calls) > 0
-            if ipython and not parsed_tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content="",
-                            parse_status=ToolCallParseStatus.failure,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
-
-            for tool_call in message.tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content=tool_call,
-                            parse_status=ToolCallParseStatus.success,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
-
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.complete,
-                    delta="",
-                    stop_reason=stop_reason,
-                )
-            )
+        return {
+            "model": self.map_to_provider_model(request.model),
+            "prompt": prompt,
+            "stream": request.stream,
+            **options,
+        }
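
Aside on the BOS handling above: `chat_completion_request_to_prompt()` already emits `<|begin_of_text|>`, and the comment notes that Fireworks always prepends BOS itself, so the adapter strips the locally added one, presumably to avoid a doubled BOS token. A tiny self-contained illustration of that stripping (the sample prompt string is made up):

BOS = "<|begin_of_text|>"

prompt = BOS + "<|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|>"
if prompt.startswith(BOS):
    # drop the locally added BOS; the provider re-adds it exactly once
    prompt = prompt[len(BOS):]

print(prompt)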


@@ -9,17 +9,22 @@ from typing import AsyncGenerator
 import httpx
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message, StopReason
+from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient

 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
-    augment_messages_for_tools,
+    chat_completion_request_to_prompt,
+)
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_options,
+    OpenAICompatCompletionChoice,
+    OpenAICompatCompletionResponse,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
 )
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

 OLLAMA_SUPPORTED_MODELS = {
     "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",

@@ -30,14 +35,10 @@ OLLAMA_SUPPORTED_MODELS = {
 }

-class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
+class OllamaInferenceAdapter(Inference):
     def __init__(self, url: str) -> None:
-        ModelRegistryHelper.__init__(
-            self, stack_to_provider_models_map=OLLAMA_SUPPORTED_MODELS
-        )
         self.url = url
-        self.tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(self.tokenizer)
+        self.formatter = ChatFormat(Tokenizer.get_instance())

     @property
     def client(self) -> AsyncClient:

@@ -55,6 +56,28 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass

+    async def register_model(self, model: ModelDef) -> None:
+        if model.identifier not in OLLAMA_SUPPORTED_MODELS:
+            raise ValueError(
+                f"Unsupported model {model.identifier}. Supported models: {OLLAMA_SUPPORTED_MODELS.keys()}"
+            )
+
+        ollama_model = OLLAMA_SUPPORTED_MODELS[model.identifier]
+        res = await self.client.ps()
+        need_model_pull = True
+        for r in res["models"]:
+            if ollama_model == r["model"]:
+                need_model_pull = False
+                break
+
+        print(f"Ollama model `{ollama_model}` needs pull -> {need_model_pull}")
+        if need_model_pull:
+            print(f"Pulling model: {ollama_model}")
+            status = await self.client.pull(ollama_model)
+            assert (
+                status["status"] == "success"
+            ), f"Failed to pull model {self.model} in ollama"
+
     def completion(
         self,
         model: str,

@@ -65,20 +88,6 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
-        options = {}
-        if request.sampling_params is not None:
-            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
-                if getattr(request.sampling_params, attr):
-                    options[attr] = getattr(request.sampling_params, attr)
-            if (
-                request.sampling_params.repetition_penalty is not None
-                and request.sampling_params.repetition_penalty != 1.0
-            ):
-                options["repeat_penalty"] = request.sampling_params.repetition_penalty
-        return options
-
     def chat_completion(
         self,
         model: str,

@@ -90,22 +99,6 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        ollama_model = self.map_to_provider_model(model)
-
-        res = await self.client.ps()
-        need_model_pull = True
-        for r in res["models"]:
-            if ollama_model == r["model"]:
-                need_model_pull = False
-                break
-
-        if need_model_pull:
-            print(f"Pulling model: {ollama_model}")
-            status = await self.client.pull(ollama_model)
-            assert (
-                status["status"] == "success"
-            ), f"Failed to pull model {self.model} in ollama"
-
         request = ChatCompletionRequest(
             model=model,
             messages=messages,

@@ -116,24 +109,16 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
             stream=stream,
             logprobs=logprobs,
         )
         if stream:
             return self._stream_chat_completion(request)
         else:
             return self._nonstream_chat_completion(request)

     def _get_params(self, request: ChatCompletionRequest) -> dict:
-        messages = augment_messages_for_tools(request)
-        model_input = self.formatter.encode_dialog_prompt(messages)
-        prompt = self.tokenizer.decode(model_input.tokens)
-
-        # accumulate sampling params and other options to pass to ollama
-        options = self.get_ollama_chat_options(request)
-
         return {
-            "model": self.map_to_provider_model(request.model),
-            "prompt": prompt,
-            "options": options,
+            "model": OLLAMA_SUPPORTED_MODELS[request.model],
+            "prompt": chat_completion_request_to_prompt(request, self.formatter),
+            "options": get_sampling_options(request),
             "raw": True,
             "stream": request.stream,
         }

@@ -143,129 +128,35 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = await self.client.generate(**params)
-        stop_reason = None
-        if r["done"]:
-            if r["done_reason"] == "stop":
-                stop_reason = StopReason.end_of_turn
-            elif r["done_reason"] == "length":
-                stop_reason = StopReason.out_of_tokens
-
-        completion_message = self.formatter.decode_assistant_message_from_content(
-            r["response"], stop_reason
-        )
-        return ChatCompletionResponse(
-            completion_message=completion_message,
-            logprobs=None,
-        )
+        assert isinstance(r, dict)
+
+        choice = OpenAICompatCompletionChoice(
+            finish_reason=r["done_reason"] if r["done"] else None,
+            text=r["response"],
+        )
+        response = OpenAICompatCompletionResponse(
+            choices=[choice],
+        )
+        return process_chat_completion_response(request, response, self.formatter)

     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
         params = self._get_params(request)

-        stream = await self.client.generate(**params)
-
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=ChatCompletionResponseEventType.start,
-                delta="",
-            )
-        )
-
-        buffer = ""
-        ipython = False
-        stop_reason = None
-
-        async for chunk in stream:
-            if chunk["done"]:
-                if stop_reason is None and chunk["done_reason"] == "stop":
-                    stop_reason = StopReason.end_of_turn
-                elif stop_reason is None and chunk["done_reason"] == "length":
-                    stop_reason = StopReason.out_of_tokens
-                break
-
-            text = chunk["response"]
-
-            # check if its a tool call ( aka starts with <|python_tag|> )
-            if not ipython and text.startswith("<|python_tag|>"):
-                ipython = True
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content="",
-                            parse_status=ToolCallParseStatus.started,
-                        ),
-                    )
-                )
-                buffer += text
-                continue
-
-            if ipython:
-                if text == "<|eot_id|>":
-                    stop_reason = StopReason.end_of_turn
-                    text = ""
-                    continue
-                elif text == "<|eom_id|>":
-                    stop_reason = StopReason.end_of_message
-                    text = ""
-                    continue
-
-                buffer += text
-                delta = ToolCallDelta(
-                    content=text,
-                    parse_status=ToolCallParseStatus.in_progress,
-                )
-
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=delta,
-                        stop_reason=stop_reason,
-                    )
-                )
-            else:
-                buffer += text
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=text,
-                        stop_reason=stop_reason,
-                    )
-                )
-
-        # parse tool calls and report errors
-        message = self.formatter.decode_assistant_message_from_content(
-            buffer, stop_reason
-        )
-        parsed_tool_calls = len(message.tool_calls) > 0
-        if ipython and not parsed_tool_calls:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.progress,
-                    delta=ToolCallDelta(
-                        content="",
-                        parse_status=ToolCallParseStatus.failure,
-                    ),
-                    stop_reason=stop_reason,
-                )
-            )
-
-        for tool_call in message.tool_calls:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.progress,
-                    delta=ToolCallDelta(
-                        content=tool_call,
-                        parse_status=ToolCallParseStatus.success,
-                    ),
-                    stop_reason=stop_reason,
-                )
-            )
-
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=ChatCompletionResponseEventType.complete,
-                delta="",
-                stop_reason=stop_reason,
-            )
-        )
+        async def _generate_and_convert_to_openai_compat():
+            s = await self.client.generate(**params)
+            async for chunk in s:
+                choice = OpenAICompatCompletionChoice(
+                    finish_reason=chunk["done_reason"] if chunk["done"] else None,
+                    text=chunk["response"],
+                )
+                yield OpenAICompatCompletionResponse(
+                    choices=[choice],
+                )
+
+        stream = _generate_and_convert_to_openai_compat()
+        async for chunk in process_chat_completion_stream_response(
+            request, stream, self.formatter
+        ):
+            yield chunk


@@ -8,7 +8,7 @@ from typing import AsyncGenerator
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message, StopReason
+from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
 from together import Together

@@ -16,9 +16,14 @@ from together import Together
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.augment_messages import (
-    augment_messages_for_tools,
+    chat_completion_request_to_prompt,
 )
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_options,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+)

 from .config import TogetherImplConfig

@@ -41,8 +46,7 @@ class TogetherInferenceAdapter(
             self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
         )
         self.config = config
-        self.tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(self.tokenizer)
+        self.formatter = ChatFormat(Tokenizer.get_instance())

     @property
     def client(self) -> Together:

@@ -64,16 +68,7 @@ class TogetherInferenceAdapter(
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def get_together_chat_options(self, request: ChatCompletionRequest) -> dict:
-        options = {}
-        if request.sampling_params is not None:
-            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
-                if getattr(request.sampling_params, attr):
-                    options[attr] = getattr(request.sampling_params, attr)
-        return options
-
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],

@@ -84,7 +79,6 @@ class TogetherInferenceAdapter(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-
         together_api_key = None
         if self.config.api_key is not None:
             together_api_key = self.config.api_key

@@ -109,148 +103,39 @@ class TogetherInferenceAdapter(
             logprobs=logprobs,
         )

-        # accumulate sampling params and other options to pass to together
-        options = self.get_together_chat_options(request)
-        together_model = self.map_to_provider_model(request.model)
-        messages = augment_messages_for_tools(request)
-        model_input = self.formatter.encode_dialog_prompt(messages)
-        prompt = self.tokenizer.decode(model_input.tokens)
-
-        if not request.stream:
-            # TODO: might need to add back an async here
-            r = client.completions.create(
-                model=together_model,
-                prompt=prompt,
-                stream=False,
-                **options,
-            )
-            stop_reason = None
-            choice = r.choices[0]
-            if choice.finish_reason:
-                if choice.finish_reason in ["stop", "eos"]:
-                    stop_reason = StopReason.end_of_turn
-                    stop_reason = StopReason.end_of_turn
-                elif choice.finish_reason == "length":
-                    stop_reason = StopReason.out_of_tokens
-
-            completion_message = self.formatter.decode_assistant_message_from_content(
-                choice.text, stop_reason
-            )
-            yield ChatCompletionResponse(
-                completion_message=completion_message,
-                logprobs=None,
-            )
-        else:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.start,
-                    delta="",
-                )
-            )
-
-            buffer = ""
-            ipython = False
-            stop_reason = None
-
-            for chunk in client.completions.create(
-                model=together_model,
-                prompt=prompt,
-                stream=True,
-                **options,
-            ):
-                choice = chunk.choices[0]
-                if finish_reason := choice.finish_reason:
-                    if stop_reason is None and finish_reason in ["stop", "eos"]:
-                        stop_reason = StopReason.end_of_turn
-                    elif stop_reason is None and finish_reason == "length":
-                        stop_reason = StopReason.out_of_tokens
-                    break
-
-                text = choice.delta.content
-                if text is None:
-                    continue
-
-                # check if its a tool call ( aka starts with <|python_tag|> )
-                if not ipython and text.startswith("<|python_tag|>"):
-                    ipython = True
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content="",
-                                parse_status=ToolCallParseStatus.started,
-                            ),
-                        )
-                    )
-                    buffer += text
-                    continue
-
-                if ipython:
-                    if text == "<|eot_id|>":
-                        stop_reason = StopReason.end_of_turn
-                        text = ""
-                        continue
-                    elif text == "<|eom_id|>":
-                        stop_reason = StopReason.end_of_message
-                        text = ""
-                        continue
-
-                    buffer += text
-                    delta = ToolCallDelta(
-                        content=text,
-                        parse_status=ToolCallParseStatus.in_progress,
-                    )
-
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=delta,
-                            stop_reason=stop_reason,
-                        )
-                    )
-                else:
-                    buffer += text
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=text,
-                            stop_reason=stop_reason,
-                        )
-                    )
-
-            # parse tool calls and report errors
-            message = self.formatter.decode_assistant_message_from_content(
-                buffer, stop_reason
-            )
-            parsed_tool_calls = len(message.tool_calls) > 0
-            if ipython and not parsed_tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content="",
-                            parse_status=ToolCallParseStatus.failure,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
-
-            for tool_call in message.tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content=tool_call,
-                            parse_status=ToolCallParseStatus.success,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
-
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.complete,
-                    delta="",
-                    stop_reason=stop_reason,
-                )
-            )
+        if stream:
+            return self._stream_chat_completion(request, client)
+        else:
+            return self._nonstream_chat_completion(request, client)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest, client: Together
+    ) -> ChatCompletionResponse:
+        params = self._get_params(request)
+        r = client.completions.create(**params)
+        return process_chat_completion_response(request, r, self.formatter)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: Together
+    ) -> AsyncGenerator:
+        params = self._get_params(request)
+
+        # if we shift to TogetherAsyncClient, we won't need this wrapper
+        async def _to_async_generator():
+            s = client.completions.create(**params)
+            for chunk in s:
+                yield chunk
+
+        stream = _to_async_generator()
+        async for chunk in process_chat_completion_stream_response(
+            request, stream, self.formatter
+        ):
+            yield chunk
+
+    def _get_params(self, request: ChatCompletionRequest) -> dict:
+        return {
+            "model": self.map_to_provider_model(request.model),
+            "prompt": chat_completion_request_to_prompt(request, self.formatter),
+            "stream": request.stream,
+            **get_sampling_options(request),
+        }
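
The `_to_async_generator` helper above wraps the synchronous Together completions iterator so it can be consumed by the async `process_chat_completion_stream_response`. A generic, runnable sketch of the same trick follows (the list stands in for the real `client.completions.create(..., stream=True)` iterator); note that the blocking iterator still runs on the event loop between chunks, which is why the diff's comment mentions moving to `TogetherAsyncClient`:

import asyncio
from typing import AsyncIterator, Iterable, TypeVar

T = TypeVar("T")


async def to_async_generator(blocking_iterable: Iterable[T]) -> AsyncIterator[T]:
    # Each item is yielded back to the event loop, but producing the next item
    # still blocks, so this is a stopgap until an async client is used.
    for item in blocking_iterable:
        yield item


async def main() -> None:
    async for chunk in to_async_generator(["chunk-1", "chunk-2", "chunk-3"]):
        print(chunk)


asyncio.run(main())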


@@ -55,7 +55,7 @@ def get_expected_stop_reason(model: str):
 @pytest_asyncio.fixture(
     scope="session",
     params=[
-        {"model": Llama_8B},
+        # {"model": Llama_8B},
         {"model": Llama_3B},
     ],
     ids=lambda d: d["model"],

@@ -112,20 +112,16 @@ def sample_tool_definition():
 @pytest.mark.asyncio
 async def test_chat_completion_non_streaming(inference_settings, sample_messages):
     inference_impl = inference_settings["impl"]
-    response = [
-        r
-        async for r in inference_impl.chat_completion(
-            messages=sample_messages,
-            stream=False,
-            **inference_settings["common_params"],
-        )
-    ]
+    response = await inference_impl.chat_completion(
+        messages=sample_messages,
+        stream=False,
+        **inference_settings["common_params"],
+    )

-    assert len(response) == 1
-    assert isinstance(response[0], ChatCompletionResponse)
-    assert response[0].completion_message.role == "assistant"
-    assert isinstance(response[0].completion_message.content, str)
-    assert len(response[0].completion_message.content) > 0
+    assert isinstance(response, ChatCompletionResponse)
+    assert response.completion_message.role == "assistant"
+    assert isinstance(response.completion_message.content, str)
+    assert len(response.completion_message.content) > 0

 @pytest.mark.asyncio

@@ -166,20 +162,16 @@ async def test_chat_completion_with_tool_calling(
         )
     ]

-    response = [
-        r
-        async for r in inference_impl.chat_completion(
-            messages=messages,
-            tools=[sample_tool_definition],
-            stream=False,
-            **inference_settings["common_params"],
-        )
-    ]
+    response = await inference_impl.chat_completion(
+        messages=messages,
+        tools=[sample_tool_definition],
+        stream=False,
+        **inference_settings["common_params"],
+    )

-    assert len(response) == 1
-    assert isinstance(response[0], ChatCompletionResponse)
+    assert isinstance(response, ChatCompletionResponse)

-    message = response[0].completion_message
+    message = response.completion_message

     # This is not supported in most providers :/ they don't return eom_id / eot_id
     # stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
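
The test changes reflect the new call surface: `chat_completion()` is now a plain `def` that returns either an awaitable response or an async generator, so non-streaming calls are awaited directly and streaming calls are consumed with `async for`. A sketch of both modes (`impl`, `model`, and `messages` are placeholders for a configured inference implementation and its inputs):

from llama_stack.apis.inference import *  # noqa: F403


async def demo(impl, model: str, messages: list) -> None:
    # Non-streaming: the call returns a coroutine yielding a ChatCompletionResponse.
    response = await impl.chat_completion(model=model, messages=messages, stream=False)
    assert isinstance(response, ChatCompletionResponse)
    print(response.completion_message.content)

    # Streaming: the call returns an async generator of stream chunks.
    async for chunk in impl.chat_completion(model=model, messages=messages, stream=True):
        print(chunk.event.event_type, chunk.event.delta)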


@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from llama_models.llama3.api.chat_format import ChatFormat
 from termcolor import cprint

 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403

@@ -19,6 +20,14 @@ from llama_models.sku_list import resolve_model
 from llama_stack.providers.utils.inference import supported_inference_models

+def chat_completion_request_to_prompt(
+    request: ChatCompletionRequest, formatter: ChatFormat
+) -> str:
+    messages = augment_messages_for_tools(request)
+    model_input = formatter.encode_dialog_prompt(messages)
+    return formatter.tokenizer.decode(model_input.tokens)
+
+
 def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]:
     """Reads chat completion request and augments the messages to handle tools.
     For eg. for llama_3_1, add system message with the appropriate tools or

@@ -48,7 +57,6 @@ def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]:
 def augment_messages_for_tools_llama_3_1(
     request: ChatCompletionRequest,
 ) -> List[Message]:
-
     assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"

     existing_messages = request.messages


@@ -0,0 +1,187 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator, Optional

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason

from llama_stack.apis.inference import *  # noqa: F403
from pydantic import BaseModel


class OpenAICompatCompletionChoiceDelta(BaseModel):
    content: str


class OpenAICompatCompletionChoice(BaseModel):
    finish_reason: Optional[str] = None
    text: Optional[str] = None
    delta: Optional[OpenAICompatCompletionChoiceDelta] = None


class OpenAICompatCompletionResponse(BaseModel):
    choices: List[OpenAICompatCompletionChoice]


def get_sampling_options(request: ChatCompletionRequest) -> dict:
    options = {}
    if params := request.sampling_params:
        for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
            if getattr(params, attr):
                options[attr] = getattr(params, attr)

        if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
            options["repeat_penalty"] = params.repetition_penalty

    return options


def text_from_choice(choice) -> str:
    if hasattr(choice, "delta") and choice.delta:
        return choice.delta.content

    return choice.text


def process_chat_completion_response(
    request: ChatCompletionRequest,
    response: OpenAICompatCompletionResponse,
    formatter: ChatFormat,
) -> ChatCompletionResponse:
    choice = response.choices[0]

    stop_reason = None
    if reason := choice.finish_reason:
        if reason in ["stop", "eos"]:
            stop_reason = StopReason.end_of_turn
        elif reason == "length":
            stop_reason = StopReason.out_of_tokens

    if stop_reason is None:
        stop_reason = StopReason.out_of_tokens

    completion_message = formatter.decode_assistant_message_from_content(
        text_from_choice(choice), stop_reason
    )
    return ChatCompletionResponse(
        completion_message=completion_message,
        logprobs=None,
    )


async def process_chat_completion_stream_response(
    request: ChatCompletionRequest,
    stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
    formatter: ChatFormat,
) -> AsyncGenerator:
    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.start,
            delta="",
        )
    )

    buffer = ""
    ipython = False
    stop_reason = None

    async for chunk in stream:
        choice = chunk.choices[0]
        finish_reason = choice.finish_reason

        if finish_reason:
            if stop_reason is None and finish_reason in ["stop", "eos"]:
                stop_reason = StopReason.end_of_turn
            elif stop_reason is None and finish_reason == "length":
                stop_reason = StopReason.out_of_tokens
            break

        text = text_from_choice(choice)

        # check if its a tool call ( aka starts with <|python_tag|> )
        if not ipython and text.startswith("<|python_tag|>"):
            ipython = True
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.progress,
                    delta=ToolCallDelta(
                        content="",
                        parse_status=ToolCallParseStatus.started,
                    ),
                )
            )
            buffer += text
            continue

        if ipython:
            if text == "<|eot_id|>":
                stop_reason = StopReason.end_of_turn
                text = ""
                continue
            elif text == "<|eom_id|>":
                stop_reason = StopReason.end_of_message
                text = ""
                continue

            buffer += text
            delta = ToolCallDelta(
                content=text,
                parse_status=ToolCallParseStatus.in_progress,
            )

            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.progress,
                    delta=delta,
                    stop_reason=stop_reason,
                )
            )
        else:
            buffer += text
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.progress,
                    delta=text,
                    stop_reason=stop_reason,
                )
            )

    # parse tool calls and report errors
    message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
    parsed_tool_calls = len(message.tool_calls) > 0
    if ipython and not parsed_tool_calls:
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.progress,
                delta=ToolCallDelta(
                    content="",
                    parse_status=ToolCallParseStatus.failure,
                ),
                stop_reason=stop_reason,
            )
        )

    for tool_call in message.tool_calls:
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.progress,
                delta=ToolCallDelta(
                    content=tool_call,
                    parse_status=ToolCallParseStatus.success,
                ),
                stop_reason=stop_reason,
            )
        )

    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.complete,
            delta="",
            stop_reason=stop_reason,
        )
    )
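
To make the shim's contract concrete, here is a sketch of how a provider adapter maps a native (non-OpenAI) reply onto these models before decoding, mirroring what the Ollama adapter does above; the `native` dict is a hypothetical provider payload, and the ChatFormat/Tokenizer setup follows the adapters in this commit.

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    process_chat_completion_response,
)


def decode_native_response(
    request: ChatCompletionRequest, native: dict
) -> ChatCompletionResponse:
    # Map the provider's fields onto the minimal completions shape the shim
    # expects: an optional finish reason plus the raw generated text.
    choice = OpenAICompatCompletionChoice(
        finish_reason=native.get("finish_reason"),
        text=native["text"],
    )
    response = OpenAICompatCompletionResponse(choices=[choice])

    # The formatter turns raw text back into a CompletionMessage, including
    # any <|python_tag|> tool calls.
    formatter = ChatFormat(Tokenizer.get_instance())
    return process_chat_completion_response(request, response, formatter)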