forked from phoenix-oss/llama-stack-mirror
Make all methods async def
again; add completion() for meta-reference (#270)
PR #201 had made several changes while trying to fix issues with getting the stream=False branches of inference and agents API working. As part of this, it made a change which was slightly gratuitous. Namely, making chat_completion() and brethren "def" instead of "async def". The rationale was that this allowed the user (within llama-stack) of this to use it as: ``` async for chunk in api.chat_completion(params) ``` However, it causes unnecessary confusion for several folks. Given that clients (e.g., llama-stack-apps) anyway use the SDK methods (which are completely isolated) this choice was not ideal. Let's revert back so the call now looks like: ``` async for chunk in await api.chat_completion(params) ``` Bonus: Added a completion() implementation for the meta-reference provider. Technically should have been another PR :)
This commit is contained in:
parent
95a96afe34
commit
2089427d60
23 changed files with 330 additions and 213 deletions
|
@ -23,11 +23,6 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
)
|
||||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
InterleavedTextMedia,
|
||||
Message,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.llama3.reference_impl.model import Transformer
|
||||
from llama_models.llama3.reference_impl.multimodal.model import (
|
||||
|
@ -38,7 +33,11 @@ from llama_models.sku_list import resolve_model
|
|||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_messages,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
|
||||
|
@ -297,15 +296,12 @@ class Llama:
|
|||
if all(eos_reached):
|
||||
break
|
||||
|
||||
def text_completion(
|
||||
def completion(
|
||||
self,
|
||||
content: InterleavedTextMedia,
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
request: CompletionRequest,
|
||||
) -> Generator:
|
||||
sampling_params = request.sampling_params
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if (
|
||||
max_gen_len is None
|
||||
or max_gen_len == 0
|
||||
|
@ -313,26 +309,25 @@ class Llama:
|
|||
):
|
||||
max_gen_len = self.model.params.max_seq_len - 1
|
||||
|
||||
model_input = self.formatter.encode_content(content)
|
||||
|
||||
model_input = self.formatter.encode_content(request.content)
|
||||
yield from self.generate(
|
||||
model_input=model_input,
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
temperature=sampling_params.temperature,
|
||||
top_p=sampling_params.top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
include_stop_token=True,
|
||||
echo=False,
|
||||
)
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
messages: List[Message],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Generator:
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
|
||||
sampling_params = request.sampling_params
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if (
|
||||
max_gen_len is None
|
||||
or max_gen_len == 0
|
||||
|
@ -343,12 +338,12 @@ class Llama:
|
|||
yield from self.generate(
|
||||
model_input=self.formatter.encode_dialog_prompt(
|
||||
messages,
|
||||
tool_prompt_format,
|
||||
request.tool_prompt_format,
|
||||
),
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=logprobs,
|
||||
temperature=sampling_params.temperature,
|
||||
top_p=sampling_params.top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
include_stop_token=True,
|
||||
)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue