add completion() for ollama (#280)

Dinesh Yeduguru 2024-10-21 22:26:33 -07:00 committed by GitHub
parent e2a5a2e10d
commit 1d241bf3fe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 138 additions and 15 deletions
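
For orientation, a minimal usage sketch of the path this change adds. It is not part of the commit: inference_impl, the model id, and the prompt text are assumptions, and the keyword arguments simply mirror the CompletionRequest fields constructed in the Ollama adapter diff below (the updated test calls inference_impl.completion(...) the same way).

    # Hypothetical sketch, not part of this commit.
    async def demo_ollama_completion(inference_impl):
        response = await inference_impl.completion(
            model="Llama3.2-1B-Instruct",   # assumed; must be a key of OLLAMA_SUPPORTED_MODELS
            content="The capital of France is",
            stream=False,
        )
        print(response.content, response.stop_reason)
        # With stream=True the adapter instead returns an async generator of
        # CompletionResponseStreamChunk(delta=..., stop_reason=...) chunks.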

@@ -23,9 +23,12 @@ from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionResponse,
     process_chat_completion_response,
     process_chat_completion_stream_response,
+    process_completion_response,
+    process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
+    completion_request_to_prompt,
 )
 
 OLLAMA_SUPPORTED_MODELS = {
@@ -93,7 +96,64 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        raise NotImplementedError()
+        request = CompletionRequest(
+            model=model,
+            content=content,
+            sampling_params=sampling_params,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        if stream:
+            return self._stream_completion(request)
+        else:
+            return await self._nonstream_completion(request)
+
+    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
+        sampling_options = get_sampling_options(request)
+        # This is needed since the Ollama API expects num_predict to be set
+        # for early truncation instead of max_tokens.
+        if sampling_options["max_tokens"] is not None:
+            sampling_options["num_predict"] = sampling_options["max_tokens"]
+
+        return {
+            "model": OLLAMA_SUPPORTED_MODELS[request.model],
+            "prompt": completion_request_to_prompt(request, self.formatter),
+            "options": sampling_options,
+            "raw": True,
+            "stream": request.stream,
+        }
+
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        params = self._get_params_for_completion(request)
+
+        async def _generate_and_convert_to_openai_compat():
+            s = await self.client.generate(**params)
+            async for chunk in s:
+                choice = OpenAICompatCompletionChoice(
+                    finish_reason=chunk["done_reason"] if chunk["done"] else None,
+                    text=chunk["response"],
+                )
+                yield OpenAICompatCompletionResponse(
+                    choices=[choice],
+                )
+
+        stream = _generate_and_convert_to_openai_compat()
+        async for chunk in process_completion_stream_response(stream, self.formatter):
+            yield chunk
+
+    async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        params = self._get_params_for_completion(request)
+        r = await self.client.generate(**params)
+        assert isinstance(r, dict)
+
+        choice = OpenAICompatCompletionChoice(
+            finish_reason=r["done_reason"] if r["done"] else None,
+            text=r["response"],
+        )
+        response = OpenAICompatCompletionResponse(
+            choices=[choice],
+        )
+
+        return process_completion_response(response, self.formatter)
+
     async def chat_completion(
         self,
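
As a reading aid, a hedged sketch of the dict that _get_params_for_completion would hand to self.client.generate(**params); the Ollama model tag is an assumption, while the keys and the max_tokens to num_predict copy follow directly from the code above.

    # Illustrative values only; the Ollama tag is assumed, not taken from the diff.
    example_params = {
        "model": "llama3.2:1b-instruct-fp16",             # assumed OLLAMA_SUPPORTED_MODELS entry
        "prompt": "The capital of France is",             # request.content decoded back to text
        "options": {"max_tokens": 50, "num_predict": 50}, # num_predict copied from max_tokens
        "raw": True,                                      # prompt is sent untemplated
        "stream": False,
    }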

@@ -4,6 +4,10 @@ providers:
     config:
       host: localhost
       port: 11434
+  - provider_id: meta-reference
+    provider_type: meta-reference
+    config:
+      model: Llama3.2-1B-Instruct
   - provider_id: test-tgi
     provider_type: remote::tgi
     config:

@@ -132,7 +132,10 @@ async def test_completion(inference_settings):
     params = inference_settings["common_params"]
 
     provider = inference_impl.routing_table.get_provider_impl(params["model"])
-    if provider.__provider_id__ != "meta-reference":
+    if provider.__provider_spec__.provider_type not in (
+        "meta-reference",
+        "remote::ollama",
+    ):
         pytest.skip("Other inference providers don't support completion() yet")
 
     response = await inference_impl.completion(

@@ -34,6 +34,8 @@ def get_sampling_options(request: ChatCompletionRequest) -> dict:
     if params := request.sampling_params:
         for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
             if getattr(params, attr):
+                if attr == "max_tokens":
+                    options["num_predict"] = getattr(params, attr)
                 options[attr] = getattr(params, attr)
 
         if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
@@ -49,25 +51,35 @@ def text_from_choice(choice) -> str:
     return choice.text
 
 
+def get_stop_reason(finish_reason: str) -> StopReason:
+    if finish_reason in ["stop", "eos"]:
+        return StopReason.end_of_turn
+    elif finish_reason == "eom":
+        return StopReason.end_of_message
+    elif finish_reason == "length":
+        return StopReason.out_of_tokens
+
+    return StopReason.out_of_tokens
+
+
+def process_completion_response(
+    response: OpenAICompatCompletionResponse, formatter: ChatFormat
+) -> CompletionResponse:
+    choice = response.choices[0]
+
+    return CompletionResponse(
+        stop_reason=get_stop_reason(choice.finish_reason),
+        content=choice.text,
+    )
+
+
 def process_chat_completion_response(
     response: OpenAICompatCompletionResponse, formatter: ChatFormat
 ) -> ChatCompletionResponse:
     choice = response.choices[0]
 
-    stop_reason = None
-    if reason := choice.finish_reason:
-        if reason in ["stop", "eos"]:
-            stop_reason = StopReason.end_of_turn
-        elif reason == "eom":
-            stop_reason = StopReason.end_of_message
-        elif reason == "length":
-            stop_reason = StopReason.out_of_tokens
-
-    if stop_reason is None:
-        stop_reason = StopReason.out_of_tokens
-
     completion_message = formatter.decode_assistant_message_from_content(
-        text_from_choice(choice), stop_reason
+        text_from_choice(choice), get_stop_reason(choice.finish_reason)
     )
     return ChatCompletionResponse(
         completion_message=completion_message,
@@ -75,6 +87,43 @@ def process_chat_completion_response(
     )
 
 
+async def process_completion_stream_response(
+    stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
+) -> AsyncGenerator:
+    stop_reason = None
+
+    async for chunk in stream:
+        choice = chunk.choices[0]
+        finish_reason = choice.finish_reason
+
+        if finish_reason:
+            if finish_reason in ["stop", "eos", "eos_token"]:
+                stop_reason = StopReason.end_of_turn
+            elif finish_reason == "length":
+                stop_reason = StopReason.out_of_tokens
+            break
+
+        text = text_from_choice(choice)
+        if text == "<|eot_id|>":
+            stop_reason = StopReason.end_of_turn
+            text = ""
+            continue
+        elif text == "<|eom_id|>":
+            stop_reason = StopReason.end_of_message
+            text = ""
+            continue
+        yield CompletionResponseStreamChunk(
+            delta=text,
+            stop_reason=stop_reason,
+        )
+
+    yield CompletionResponseStreamChunk(
+        delta="",
+        stop_reason=stop_reason,
+    )
+
+
 async def process_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
 ) -> AsyncGenerator:

@@ -23,6 +23,13 @@ from llama_models.sku_list import resolve_model
 from llama_stack.providers.utils.inference import supported_inference_models
 
 
+def completion_request_to_prompt(
+    request: CompletionRequest, formatter: ChatFormat
+) -> str:
+    model_input = formatter.encode_content(request.content)
+    return formatter.tokenizer.decode(model_input.tokens)
+
+
 def chat_completion_request_to_prompt(
     request: ChatCompletionRequest, formatter: ChatFormat
 ) -> str:
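
To make the new helper concrete, a small hedged sketch of the round trip it performs; constructing the formatter this way is an assumption, but the encode_content/decode calls mirror completion_request_to_prompt above.

    # Hypothetical sketch, not part of this commit.
    from llama_models.llama3.api.chat_format import ChatFormat
    from llama_models.llama3.api.tokenizer import Tokenizer

    formatter = ChatFormat(Tokenizer.get_instance())
    model_input = formatter.encode_content("The capital of France is")
    prompt = formatter.tokenizer.decode(model_input.tokens)
    # Unlike chat_completion_request_to_prompt, no chat template is applied here;
    # the Ollama adapter then sends this prompt with raw=True.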