introduce openai_compat with the completions (not chat-completions) API

This keeps the prompt encoding layer in our control (see the
`chat_completion_request_to_prompt()` method).
Ashwin Bharambe 2024-10-08 12:15:55 -07:00 committed by Ashwin Bharambe
parent 0c9eb3341c
commit 05e73d12b3
6 changed files with 354 additions and 513 deletions
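
The `chat_completion_request_to_prompt()` helper mentioned in the commit message is not shown in this diff; it lives with the shared `augment_messages` utilities. A minimal sketch of what it plausibly does, modeled on the inline code this commit deletes from `_get_params()` further down (illustrative only, not the actual implementation):

# Illustrative sketch, not code from this commit. Assumes the helper wraps the
# same augment_messages_for_tools() / ChatFormat logic that the old
# _get_params() inlined (see the deleted lines below), and that ChatFormat
# keeps a reference to its tokenizer.
def chat_completion_request_to_prompt(request, formatter) -> str:
    messages = augment_messages_for_tools(request)           # tool-aware message rewriting
    model_input = formatter.encode_dialog_prompt(messages)   # messages -> token ids
    return formatter.tokenizer.decode(model_input.tokens)    # token ids -> raw prompt string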


@@ -9,17 +9,22 @@ from typing import AsyncGenerator
 import httpx
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message, StopReason
+from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
 
 from ollama import AsyncClient
 
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
-    augment_messages_for_tools,
+    chat_completion_request_to_prompt,
 )
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_options,
+    OpenAICompatCompletionChoice,
+    OpenAICompatCompletionResponse,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+)
 
 OLLAMA_SUPPORTED_MODELS = {
     "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
@@ -30,14 +35,10 @@ OLLAMA_SUPPORTED_MODELS = {
 }
 
 
-class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
+class OllamaInferenceAdapter(Inference):
     def __init__(self, url: str) -> None:
-        ModelRegistryHelper.__init__(
-            self, stack_to_provider_models_map=OLLAMA_SUPPORTED_MODELS
-        )
         self.url = url
-        self.tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(self.tokenizer)
+        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     @property
     def client(self) -> AsyncClient:
@@ -55,6 +56,28 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass
 
+    async def register_model(self, model: ModelDef) -> None:
+        if model.identifier not in OLLAMA_SUPPORTED_MODELS:
+            raise ValueError(
+                f"Unsupported model {model.identifier}. Supported models: {OLLAMA_SUPPORTED_MODELS.keys()}"
+            )
+
+        ollama_model = OLLAMA_SUPPORTED_MODELS[model.identifier]
+        res = await self.client.ps()
+        need_model_pull = True
+        for r in res["models"]:
+            if ollama_model == r["model"]:
+                need_model_pull = False
+                break
+
+        print(f"Ollama model `{ollama_model}` needs pull -> {need_model_pull}")
+        if need_model_pull:
+            print(f"Pulling model: {ollama_model}")
+            status = await self.client.pull(ollama_model)
+            assert (
+                status["status"] == "success"
+            ), f"Failed to pull model {self.model} in ollama"
+
     def completion(
         self,
         model: str,
@@ -65,20 +88,6 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()
 
-    def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
-        options = {}
-        if request.sampling_params is not None:
-            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
-                if getattr(request.sampling_params, attr):
-                    options[attr] = getattr(request.sampling_params, attr)
-            if (
-                request.sampling_params.repetition_penalty is not None
-                and request.sampling_params.repetition_penalty != 1.0
-            ):
-                options["repeat_penalty"] = request.sampling_params.repetition_penalty
-
-        return options
-
     def chat_completion(
         self,
         model: str,
@@ -90,22 +99,6 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        ollama_model = self.map_to_provider_model(model)
-
-        res = await self.client.ps()
-        need_model_pull = True
-        for r in res["models"]:
-            if ollama_model == r["model"]:
-                need_model_pull = False
-                break
-
-        if need_model_pull:
-            print(f"Pulling model: {ollama_model}")
-            status = await self.client.pull(ollama_model)
-            assert (
-                status["status"] == "success"
-            ), f"Failed to pull model {self.model} in ollama"
-
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
@@ -116,24 +109,16 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
             stream=stream,
             logprobs=logprobs,
         )
         if stream:
            return self._stream_chat_completion(request)
         else:
             return self._nonstream_chat_completion(request)
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
-        messages = augment_messages_for_tools(request)
-        model_input = self.formatter.encode_dialog_prompt(messages)
-        prompt = self.tokenizer.decode(model_input.tokens)
-
-        # accumulate sampling params and other options to pass to ollama
-        options = self.get_ollama_chat_options(request)
-
         return {
-            "model": self.map_to_provider_model(request.model),
-            "prompt": prompt,
-            "options": options,
+            "model": OLLAMA_SUPPORTED_MODELS[request.model],
+            "prompt": chat_completion_request_to_prompt(request, self.formatter),
+            "options": get_sampling_options(request),
             "raw": True,
             "stream": request.stream,
         }
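
The sampling-option translation that previously lived in `get_ollama_chat_options()` (deleted in an earlier hunk) now comes from the shared `get_sampling_options()` helper. A sketch of the behavior it is expected to preserve, modeled directly on the deleted method; the actual code in `openai_compat.py` may differ in details:

# Sketch only, modeled on the deleted get_ollama_chat_options(); key names and
# defaults in the real openai_compat.get_sampling_options() may differ.
def get_sampling_options(request) -> dict:
    options = {}
    params = request.sampling_params
    if params is not None:
        # copy the basic sampling knobs straight through when they are set
        for attr in ("temperature", "top_p", "top_k", "max_tokens"):
            if getattr(params, attr):
                options[attr] = getattr(params, attr)
        # ollama calls this "repeat_penalty"; only forward non-default values
        if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
            options["repeat_penalty"] = params.repetition_penalty
    return options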
@@ -143,129 +128,35 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = await self.client.generate(**params)
-        stop_reason = None
-        if r["done"]:
-            if r["done_reason"] == "stop":
-                stop_reason = StopReason.end_of_turn
-            elif r["done_reason"] == "length":
-                stop_reason = StopReason.out_of_tokens
-
-        completion_message = self.formatter.decode_assistant_message_from_content(
-            r["response"], stop_reason
-        )
-        return ChatCompletionResponse(
-            completion_message=completion_message,
-            logprobs=None,
-        )
+        assert isinstance(r, dict)
+
+        choice = OpenAICompatCompletionChoice(
+            finish_reason=r["done_reason"] if r["done"] else None,
+            text=r["response"],
+        )
+        response = OpenAICompatCompletionResponse(
+            choices=[choice],
+        )
+        return process_chat_completion_response(request, response, self.formatter)
 
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
         params = self._get_params(request)
 
-        stream = await self.client.generate(**params)
-
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=ChatCompletionResponseEventType.start,
-                delta="",
-            )
-        )
-
-        buffer = ""
-        ipython = False
-        stop_reason = None
-
-        async for chunk in stream:
-            if chunk["done"]:
-                if stop_reason is None and chunk["done_reason"] == "stop":
-                    stop_reason = StopReason.end_of_turn
-                elif stop_reason is None and chunk["done_reason"] == "length":
-                    stop_reason = StopReason.out_of_tokens
-                break
-
-            text = chunk["response"]
-
-            # check if its a tool call ( aka starts with <|python_tag|> )
-            if not ipython and text.startswith("<|python_tag|>"):
-                ipython = True
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content="",
-                            parse_status=ToolCallParseStatus.started,
-                        ),
-                    )
-                )
-                buffer += text
-                continue
-
-            if ipython:
-                if text == "<|eot_id|>":
-                    stop_reason = StopReason.end_of_turn
-                    text = ""
-                    continue
-                elif text == "<|eom_id|>":
-                    stop_reason = StopReason.end_of_message
-                    text = ""
-                    continue
-
-                buffer += text
-                delta = ToolCallDelta(
-                    content=text,
-                    parse_status=ToolCallParseStatus.in_progress,
-                )
-
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=delta,
-                        stop_reason=stop_reason,
-                    )
-                )
-            else:
-                buffer += text
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=text,
-                        stop_reason=stop_reason,
-                    )
-                )
-
-        # parse tool calls and report errors
-        message = self.formatter.decode_assistant_message_from_content(
-            buffer, stop_reason
-        )
-        parsed_tool_calls = len(message.tool_calls) > 0
-        if ipython and not parsed_tool_calls:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.progress,
-                    delta=ToolCallDelta(
-                        content="",
-                        parse_status=ToolCallParseStatus.failure,
-                    ),
-                    stop_reason=stop_reason,
-                )
-            )
-
-        for tool_call in message.tool_calls:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.progress,
-                    delta=ToolCallDelta(
-                        content=tool_call,
-                        parse_status=ToolCallParseStatus.success,
-                    ),
-                    stop_reason=stop_reason,
-                )
-            )
-
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=ChatCompletionResponseEventType.complete,
-                delta="",
-                stop_reason=stop_reason,
-            )
-        )
+        async def _generate_and_convert_to_openai_compat():
+            s = await self.client.generate(**params)
+            async for chunk in s:
+                choice = OpenAICompatCompletionChoice(
+                    finish_reason=chunk["done_reason"] if chunk["done"] else None,
+                    text=chunk["response"],
+                )
+                yield OpenAICompatCompletionResponse(
+                    choices=[choice],
+                )
+
+        stream = _generate_and_convert_to_openai_compat()
+        async for chunk in process_chat_completion_stream_response(
+            request, stream, self.formatter
+        ):
+            yield chunk
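
For orientation, a rough sketch of what the shared `process_chat_completion_stream_response()` helper presumably does with the converted stream, reconstructed from the inline streaming logic this commit deletes (assumed shape only; the real implementation lives in `openai_compat.py` and also handles the `<|python_tag|>` tool-call states omitted here):

# Assumed sketch, modeled on the deleted streaming code above. Relies on
# StopReason from llama_models and on the event/delta types star-imported
# from llama_stack.apis.inference in this module.
async def process_chat_completion_stream_response(request, stream, formatter):
    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.start,
            delta="",
        )
    )

    buffer = ""
    stop_reason = None
    async for chunk in stream:
        choice = chunk.choices[0]
        if choice.finish_reason:
            # map provider finish reasons back to StopReason, as the old code did
            stop_reason = (
                StopReason.end_of_turn
                if choice.finish_reason == "stop"
                else StopReason.out_of_tokens
            )
            break

        buffer += choice.text
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.progress,
                delta=choice.text,
            )
        )

    # re-parse the accumulated text so tool calls surface as ToolCallDelta events
    message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
    for tool_call in message.tool_calls:
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.progress,
                delta=ToolCallDelta(
                    content=tool_call,
                    parse_status=ToolCallParseStatus.success,
                ),
                stop_reason=stop_reason,
            )
        )

    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.complete,
            delta="",
            stop_reason=stop_reason,
        )
    )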