Merge branch 'meta-llama:main' into qdrant

This commit is contained in:
Anush 2024-10-22 21:45:31 +05:30 committed by GitHub
commit 1575578446
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
101 changed files with 3310 additions and 722 deletions

View file

@ -47,7 +47,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def shutdown(self) -> None:
self.client.close()
def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -283,7 +283,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
)
return tool_config
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],

View file

@ -7,10 +7,11 @@
from .config import DatabricksImplConfig
from .databricks import DatabricksInferenceAdapter
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
assert isinstance(
config, DatabricksImplConfig
), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config)
await impl.initialize()
return impl
return impl

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@ -19,4 +18,4 @@ class DatabricksImplConfig(BaseModel):
api_token: str = Field(
default=None,
description="The Databricks API token",
)
)

View file

@ -48,10 +48,17 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
async def shutdown(self) -> None:
pass
def completion(self, request: CompletionRequest) -> AsyncGenerator:
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@ -77,14 +84,14 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
if stream:
return self._stream_chat_completion(request, client)
else:
return self._nonstream_chat_completion(request, client)
return await self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(request, r, self.formatter)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
@ -98,7 +105,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
stream, self.formatter
):
yield chunk

View file

@ -51,7 +51,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
async def shutdown(self) -> None:
pass
def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -61,7 +61,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
) -> AsyncGenerator:
raise NotImplementedError()
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@ -87,14 +87,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
if stream:
return self._stream_chat_completion(request, client)
else:
return self._nonstream_chat_completion(request, client)
return await self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks
) -> ChatCompletionResponse:
params = self._get_params(request)
r = await client.completion.acreate(**params)
return process_chat_completion_response(request, r, self.formatter)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks
@ -103,7 +103,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
stream = client.completion.acreate(**params)
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
stream, self.formatter
):
yield chunk

View file

@ -23,9 +23,12 @@ from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionResponse,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
)
OLLAMA_SUPPORTED_MODELS = {
@ -33,7 +36,8 @@ OLLAMA_SUPPORTED_MODELS = {
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
"Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
"Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
"Llama-Guard-3-8B": "xe/llamaguard3:latest",
"Llama-Guard-3-8B": "llama-guard3:8b",
"Llama-Guard-3-1B": "llama-guard3:1b",
}
@ -84,7 +88,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return ret
def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -92,9 +96,66 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
request = CompletionRequest(
model=model,
content=content,
sampling_params=sampling_params,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
def chat_completion(
def _get_params_for_completion(self, request: CompletionRequest) -> dict:
    """Build the keyword arguments for ``self.client.generate``.

    Args:
        request: The completion request; ``request.model`` must be a key of
            ``OLLAMA_SUPPORTED_MODELS`` (a ``KeyError`` is raised otherwise).

    Returns:
        A dict with the Ollama model name, the raw prompt rendered from the
        request, the sampling options and the stream flag.
    """
    sampling_options = get_sampling_options(request)
    # This is needed since the Ollama API expects num_predict to be set
    # for early truncation instead of max_tokens.
    # Use .get() so a missing "max_tokens" entry is treated like an unset
    # limit instead of raising KeyError.
    max_tokens = sampling_options.get("max_tokens")
    if max_tokens is not None:
        sampling_options["num_predict"] = max_tokens
    return {
        "model": OLLAMA_SUPPORTED_MODELS[request.model],
        "prompt": completion_request_to_prompt(request, self.formatter),
        "options": sampling_options,
        # The prompt is already fully formatted, so ask for raw mode.
        "raw": True,
        "stream": request.stream,
    }
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
    """Stream a completion, yielding the chunks produced by
    ``process_completion_stream_response``.

    Each chunk returned by ``self.client.generate`` is first wrapped in an
    ``OpenAICompatCompletionResponse`` holding a single choice.
    """
    generate_kwargs = self._get_params_for_completion(request)

    async def _as_openai_compat():
        raw_stream = await self.client.generate(**generate_kwargs)
        async for piece in raw_stream:
            # A finish reason is only reported on the chunk flagged as done.
            if piece["done"]:
                finish_reason = piece["done_reason"]
            else:
                finish_reason = None
            yield OpenAICompatCompletionResponse(
                choices=[
                    OpenAICompatCompletionChoice(
                        finish_reason=finish_reason,
                        text=piece["response"],
                    )
                ],
            )

    converted = _as_openai_compat()
    async for out in process_completion_stream_response(converted, self.formatter):
        yield out
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
    """Run a single, non-streaming completion.

    The result of ``self.client.generate`` is wrapped in an
    ``OpenAICompatCompletionResponse`` with one choice and handed to
    ``process_completion_response``, whose return value is returned as-is.

    NOTE(review): the ``AsyncGenerator`` annotation looks inaccurate for a
    non-streaming call -- confirm the real return type of
    ``process_completion_response``.
    """
    result = await self.client.generate(**self._get_params_for_completion(request))
    assert isinstance(result, dict)

    # A finish reason is only taken from the result when it is flagged as done.
    if result["done"]:
        finish_reason = result["done_reason"]
    else:
        finish_reason = None

    wrapped = OpenAICompatCompletionResponse(
        choices=[
            OpenAICompatCompletionChoice(
                finish_reason=finish_reason,
                text=result["response"],
            )
        ],
    )
    return process_completion_response(wrapped, self.formatter)
async def chat_completion(
self,
model: str,
messages: List[Message],
@ -118,7 +179,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
if stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
return await self._nonstream_chat_completion(request)
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
@ -143,7 +204,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(request, response, self.formatter)
return process_chat_completion_response(response, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
@ -163,7 +224,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
stream, self.formatter
):
yield chunk

View file

@ -66,7 +66,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
async def shutdown(self) -> None:
pass
def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -76,7 +76,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
) -> AsyncGenerator:
raise NotImplementedError()
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@ -101,7 +101,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
if stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
@ -116,7 +116,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(request, response, self.formatter)
return process_chat_completion_response(response, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
@ -135,7 +135,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
stream, self.formatter
):
yield chunk

View file

@ -64,7 +64,7 @@ class TogetherInferenceAdapter(
) -> AsyncGenerator:
raise NotImplementedError()
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@ -101,14 +101,14 @@ class TogetherInferenceAdapter(
if stream:
return self._stream_chat_completion(request, client)
else:
return self._nonstream_chat_completion(request, client)
return await self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: Together
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(request, r, self.formatter)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: Together
@ -123,7 +123,7 @@ class TogetherInferenceAdapter(
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
stream, self.formatter
):
yield chunk

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import VLLMImplConfig
from .vllm import VLLMInferenceAdapter
async def get_adapter_impl(config: VLLMImplConfig, _deps):
    """Create a ``VLLMInferenceAdapter`` for ``config``, initialize it and return it.

    ``_deps`` is accepted but not used by this function.
    """
    assert isinstance(
        config, VLLMImplConfig
    ), f"Unexpected config type: {type(config)}"

    adapter = VLLMInferenceAdapter(config)
    await adapter.initialize()
    return adapter

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class VLLMImplConfig(BaseModel):
    # Settings for reaching a vLLM serving endpoint. Both fields are optional
    # and default to None.

    # Address of the vLLM serving endpoint.
    url: Optional[str] = Field(
        description="The URL for the vLLM model serving endpoint",
        default=None,
    )

    # API token for the endpoint.
    api_token: Optional[str] = Field(
        description="The API token",
        default=None,
    )

View file

@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from .config import VLLMImplConfig
# Maps the model names accepted in requests (looked up via `request.model`)
# to the model identifier that is sent as the `model` parameter to the server.
# A request for a name that is not listed here fails with KeyError at lookup.
# NOTE(review): the values look like Hugging Face repository ids -- confirm.
VLLM_SUPPORTED_MODELS = {
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
"Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
"Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
"Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
"Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
"Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
"Llama3.2-1B": "meta-llama/Llama-3.2-1B",
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
"Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision",
"Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision",
"Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
"Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
"Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4",
"Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
"Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
"Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8",
"Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M",
"Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B",
}
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
    """Inference provider that sends prompts to a vLLM endpoint.

    Requests are rendered to a raw prompt with the Llama chat formatter and
    submitted through the ``openai`` client's ``completions.create`` call,
    using ``config.url`` as the base URL.
    """

    def __init__(self, config: VLLMImplConfig) -> None:
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())
        # The client is only created in initialize().
        self.client = None

    async def initialize(self) -> None:
        """Create the OpenAI client from the configured URL and API token."""
        self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)

    async def register_model(self, model: ModelDef) -> None:
        """Always raises ``ValueError``: models cannot be registered here."""
        raise ValueError("Model registration is not supported for vLLM models")

    async def shutdown(self) -> None:
        """No-op."""
        pass

    async def list_models(self) -> List[ModelDef]:
        """Return one ``ModelDef`` per model reported by the endpoint."""
        return [
            ModelDef(identifier=model.id, llama_model=model.id)
            for model in self.client.models.list()
        ]

    # Declared `async` to match the other inference adapters, whose
    # `completion` methods are all coroutines.
    async def completion(
        self,
        model: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
        """Not implemented for this provider; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    # Declared `async` because callers do
    # `async for chunk in await inference_api.chat_completion(...)`; a plain
    # `def` would hand them an async generator that cannot be awaited.
    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Run a chat completion against the vLLM endpoint.

        Returns:
            An async generator of stream chunks when ``stream`` is truthy,
            otherwise the processed (already awaited) chat completion response.
        """
        request = ChatCompletionRequest(
            model=model,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_chat_completion(request, self.client)
        else:
            return await self._nonstream_chat_completion(request, self.client)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest, client: OpenAI
    ) -> ChatCompletionResponse:
        """Issue one blocking completions call and convert its result."""
        params = self._get_params(request)
        r = client.completions.create(**params)
        # The response helper takes (response, formatter), as in the other adapters.
        return process_chat_completion_response(r, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest, client: OpenAI
    ) -> AsyncGenerator:
        """Yield processed chunks for a streaming completions call."""
        params = self._get_params(request)

        # TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
        # generator so this wrapper is not necessary?
        async def _to_async_generator():
            s = client.completions.create(**params)
            for chunk in s:
                yield chunk

        stream = _to_async_generator()
        # The stream helper takes (stream, formatter), as in the other adapters.
        async for chunk in process_chat_completion_stream_response(
            stream, self.formatter
        ):
            yield chunk

    def _get_params(self, request: ChatCompletionRequest) -> dict:
        """Build the keyword arguments for ``client.completions.create``.

        Raises:
            KeyError: if ``request.model`` is not in ``VLLM_SUPPORTED_MODELS``.
        """
        return {
            "model": VLLM_SUPPORTED_MODELS[request.model],
            "prompt": chat_completion_request_to_prompt(request, self.formatter),
            "stream": request.stream,
            **get_sampling_options(request),
        }

    async def embeddings(
        self,
        model: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        """Not implemented for this provider; always raises ``NotImplementedError``."""
        raise NotImplementedError()

View file

@ -424,7 +424,7 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason = None
with tracing.span("inference"):
async for chunk in self.inference_api.chat_completion(
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self._get_tools(),

View file

@ -105,7 +105,7 @@ class MetaReferenceAgentsImpl(Agents):
session_id=session_id,
)
def create_agent_turn(
async def create_agent_turn(
self,
agent_id: str,
session_id: str,

View file

@ -17,13 +17,22 @@ from llama_stack.providers.utils.inference import supported_inference_models
class MetaReferenceInferenceConfig(BaseModel):
model: str = Field(
default="Llama3.1-8B-Instruct",
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
)
torch_seed: Optional[int] = None
max_seq_len: int = 4096
max_batch_size: int = 1
# when this is False, we assume that the distributed process group is setup by someone
# outside of this code (e.g., when run inside `torchrun`). that is useful for clients
# (including our testing code) who might be using llama-stack as a library.
create_distributed_process_group: bool = True
# By default, the implementation will look at ~/.llama/checkpoints/<model> but you
# can override by specifying the directory explicitly
checkpoint_dir: Optional[str] = None
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:

View file

@ -23,11 +23,6 @@ from fairscale.nn.model_parallel.initialize import (
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.datatypes import (
InterleavedTextMedia,
Message,
ToolPromptFormat,
)
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
@ -38,7 +33,11 @@ from llama_models.sku_list import resolve_model
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
@ -98,7 +97,10 @@ class Llama:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
ckpt_dir = model_checkpoint_dir(model)
if config.checkpoint_dir:
ckpt_dir = config.checkpoint_dir
else:
ckpt_dir = model_checkpoint_dir(model)
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
@ -119,9 +121,7 @@ class Llama:
**params,
)
tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")
tokenizer = Tokenizer(model_path=tokenizer_path)
tokenizer = Tokenizer.get_instance()
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
@ -138,7 +138,7 @@ class Llama:
else:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
model = convert_to_quantized_model(model, config)
model = convert_to_quantized_model(model, config, ckpt_dir)
else:
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
@ -170,14 +170,16 @@ class Llama:
logprobs: bool = False,
echo: bool = False,
include_stop_token: bool = False,
print_input_tokens: bool = False,
) -> Generator:
params = self.model.params
# input_tokens = [
# self.formatter.vision_token if t == 128256 else t
# for t in model_input.tokens
# ]
# cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
if print_input_tokens:
input_tokens = [
self.formatter.vision_token if t == 128256 else t
for t in model_input.tokens
]
cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
prompt_tokens = [model_input.tokens]
bsz = 1
@ -228,8 +230,7 @@ class Llama:
ignore_index=pad_id,
)
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(
@ -295,15 +296,12 @@ class Llama:
if all(eos_reached):
break
def text_completion(
def completion(
self,
content: InterleavedTextMedia,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
request: CompletionRequest,
) -> Generator:
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
max_gen_len is None
or max_gen_len == 0
@ -311,26 +309,25 @@ class Llama:
):
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(content)
model_input = self.formatter.encode_content(request.content)
yield from self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
echo=echo,
temperature=sampling_params.temperature,
top_p=sampling_params.top_p,
logprobs=bool(request.logprobs),
include_stop_token=True,
echo=False,
)
def chat_completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
request: ChatCompletionRequest,
) -> Generator:
messages = chat_completion_request_to_messages(request)
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
max_gen_len is None
or max_gen_len == 0
@ -341,12 +338,12 @@ class Llama:
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(
messages,
tool_prompt_format,
request.tool_prompt_format,
),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
temperature=sampling_params.temperature,
top_p=sampling_params.top_p,
logprobs=bool(request.logprobs),
include_stop_token=True,
)

View file

@ -13,11 +13,9 @@ from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)
from .config import MetaReferenceInferenceConfig
from .generation import Llama
from .model_parallel import LlamaModelParallelGenerator
# there's a single model parallel process running serving the model. for now,
@ -36,8 +34,11 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
async def initialize(self) -> None:
print(f"Loading model `{self.model.descriptor()}`")
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()
else:
self.generator = Llama.build(self.config)
async def register_model(self, model: ModelDef) -> None:
raise ValueError("Dynamic model registration is not supported")
@ -51,9 +52,21 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
]
async def shutdown(self) -> None:
self.generator.stop()
if self.config.create_distributed_process_group:
self.generator.stop()
def completion(
def check_model(self, request) -> None:
    """Ensure ``request.model`` names the model this implementation holds.

    Raises:
        RuntimeError: if ``resolve_model`` returns ``None`` for the requested
            name, or if the resolved model's descriptor differs from
            ``self.model.descriptor()``.
    """
    resolved = resolve_model(request.model)
    if resolved is None:
        raise RuntimeError(
            f"Unknown model: {request.model}, Run `llama model list`"
        )

    if resolved.descriptor() != self.model.descriptor():
        raise RuntimeError(
            f"Model mismatch: {request.model} != {self.model.descriptor()}"
        )
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -61,9 +74,114 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
raise NotImplementedError()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
def chat_completion(
request = CompletionRequest(
model=model,
content=content,
sampling_params=sampling_params,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
if request.stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
    """Stream a completion as ``CompletionResponseStreamChunk`` objects.

    One chunk is yielded per generated token. The special end-of-turn and
    end-of-message tokens produce an empty delta carrying the stop reason.
    If generation ends without either token, a final empty chunk with
    ``StopReason.out_of_tokens`` is yielded.
    """

    def produce_chunks():
        reason_by_token = {
            "<|eot_id|>": StopReason.end_of_turn,
            "<|eom_id|>": StopReason.end_of_message,
        }
        stop_reason = None

        for result in self.generator.completion(request):
            matched = reason_by_token.get(result.text)
            if matched is not None:
                # Stop tokens are not surfaced as text.
                stop_reason = matched
                delta = ""
            else:
                delta = result.text

            # Log-probs are only attached while no stop token has been seen.
            token_logprobs = None
            if stop_reason is None and request.logprobs:
                assert len(result.logprobs) == 1
                token_logprobs = [
                    TokenLogProbs(
                        logprobs_by_token={result.text: result.logprobs[0]}
                    )
                ]

            yield CompletionResponseStreamChunk(
                delta=delta,
                stop_reason=stop_reason,
                logprobs=token_logprobs,
            )

        if stop_reason is None:
            yield CompletionResponseStreamChunk(
                delta="",
                stop_reason=StopReason.out_of_tokens,
            )

    if not self.config.create_distributed_process_group:
        for item in produce_chunks():
            yield item
    else:
        # The semaphore is held for the whole stream in this mode.
        async with SEMAPHORE:
            for item in produce_chunks():
                yield item
async def _nonstream_completion(
    self, request: CompletionRequest
) -> CompletionResponse:
    """Generate a full completion and return it as one ``CompletionResponse``.

    All generated tokens (including any stop token) are collected and decoded
    together. The stop reason is ``end_of_turn`` if any stop token was
    produced, otherwise ``out_of_tokens``.
    """

    def run_generation():
        tokenizer = self.generator.formatter.tokenizer
        collected_tokens = []
        collected_logprobs = []
        stop_reason = None

        for result in self.generator.completion(request):
            collected_tokens.append(result.token)

            if result.token in tokenizer.stop_tokens:
                # not quite right semantically
                stop_reason = StopReason.end_of_turn

            if request.logprobs:
                assert len(result.logprobs) == 1
                entry = TokenLogProbs(
                    logprobs_by_token={result.text: result.logprobs[0]}
                )
                collected_logprobs.append(entry)

        if stop_reason is None:
            stop_reason = StopReason.out_of_tokens

        decoded = self.generator.formatter.tokenizer.decode(collected_tokens)
        return CompletionResponse(
            content=decoded,
            stop_reason=stop_reason,
            logprobs=collected_logprobs if request.logprobs else None,
        )

    if not self.config.create_distributed_process_group:
        return run_generation()

    # The semaphore is held while generating in this mode.
    async with SEMAPHORE:
        return run_generation()
async def chat_completion(
self,
model: str,
messages: List[Message],
@ -88,43 +206,26 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
model = resolve_model(request.model)
if model is None:
raise RuntimeError(
f"Unknown model: {request.model}, Run `llama model list`"
)
elif model.descriptor() != self.model.descriptor():
raise RuntimeError(
f"Model mismatch: {request.model} != {self.model.descriptor()}"
)
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
if request.stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async with SEMAPHORE:
messages = chat_completion_request_to_messages(request)
def impl():
tokens = []
logprobs = []
stop_reason = None
for token_result in self.generator.chat_completion(
messages=messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
for token_result in self.generator.chat_completion(request):
tokens.append(token_result.token)
if token_result.text == "<|eot_id|>":
@ -154,12 +255,16 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
logprobs=logprobs if request.logprobs else None,
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
return impl()
else:
return impl()
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async with SEMAPHORE:
messages = chat_completion_request_to_messages(request)
def impl():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
@ -172,14 +277,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
stop_reason = None
ipython = False
for token_result in self.generator.chat_completion(
messages=messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
for token_result in self.generator.chat_completion(request):
tokens.append(token_result.token)
if not ipython and token_result.text.startswith("<|python_tag|>"):
@ -272,6 +370,14 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
)
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x
async def embeddings(
self,
model: str,

View file

@ -7,16 +7,17 @@
import os
from copy import deepcopy
from functools import partial
from typing import Generator, List, Optional
from typing import Any, Generator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from .config import MetaReferenceInferenceConfig
from .generation import Llama, model_checkpoint_dir
from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
from .parallel_utils import ModelParallelProcessGroup
class ModelRunner:
@ -24,15 +25,13 @@ class ModelRunner:
self.llama = llama
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, task: InferenceArgs):
return self.llama.chat_completion(
task.messages,
task.temperature,
task.top_p,
task.max_gen_len,
task.logprobs,
task.tool_prompt_format,
)
def __call__(self, req: Any):
if isinstance(req, ChatCompletionRequest):
return self.llama.chat_completion(req)
elif isinstance(req, CompletionRequest):
return self.llama.completion(req)
else:
raise ValueError(f"Unexpected task type {type(req)}")
def init_model_cb(config: MetaReferenceInferenceConfig):
@ -77,23 +76,18 @@ class LlamaModelParallelGenerator:
def __exit__(self, exc_type, exc_value, exc_traceback):
self.group.stop()
def chat_completion(
def completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
request: CompletionRequest,
) -> Generator:
req_obj = InferenceArgs(
messages=deepcopy(messages),
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs or False,
tool_prompt_format=tool_prompt_format,
)
req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)
yield from gen
def chat_completion(
    self,
    request: ChatCompletionRequest,
) -> Generator:
    """Run a chat completion through the process group and stream its results.

    A deep copy of ``request`` is what gets passed to
    ``self.group.run_inference``; everything that call yields is re-yielded
    to the caller.
    """
    yield from self.group.run_inference(deepcopy(request))

View file

@ -4,6 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import multiprocessing
import os
@ -11,10 +17,9 @@ import tempfile
import time
import uuid
from enum import Enum
from typing import Callable, Generator, List, Literal, Optional, Union
from typing import Callable, Generator, Literal, Optional, Union
import torch
import zmq
from fairscale.nn.model_parallel.initialize import (
@ -23,25 +28,16 @@ from fairscale.nn.model_parallel.initialize import (
get_model_parallel_src_rank,
)
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from typing_extensions import Annotated
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from .generation import TokenResult
class InferenceArgs(BaseModel):
messages: List[Message]
temperature: float
top_p: float
max_gen_len: int
logprobs: bool
tool_prompt_format: ToolPromptFormat
class ProcessingMessageName(str, Enum):
ready_request = "ready_request"
ready_response = "ready_response"
@ -80,7 +76,7 @@ class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = (
ProcessingMessageName.task_request
)
task: InferenceArgs
task: Union[CompletionRequest, ChatCompletionRequest]
class TaskResponse(BaseModel):
@ -349,11 +345,13 @@ class ModelParallelProcessGroup:
self.process.join()
self.started = False
def run_inference(self, inference_args: InferenceArgs) -> Generator:
def run_inference(
self, req: Union[CompletionRequest, ChatCompletionRequest]
) -> Generator:
assert not self.running, "inference already running"
self.running = True
self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
self.request_socket.send(encode_msg(TaskRequest(task=req)))
try:
while True:
obj_json = self.request_socket.recv()

View file

@ -13,9 +13,10 @@ from typing import Optional
import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from llama_models.sku_list import resolve_model
from termcolor import cprint
from torch import Tensor
@ -39,6 +40,7 @@ def swiglu_wrapper(
def convert_to_quantized_model(
model: Transformer,
config: MetaReferenceQuantizedInferenceConfig,
checkpoint_dir: str,
fp8_activation_scale_ub: Optional[float] = 1200.0,
) -> Transformer:
if config.quantization.type == QuantizationType.bf16.value:
@ -49,12 +51,14 @@ def convert_to_quantized_model(
from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
checkpoint = config.checkpoint_config.checkpoint
llama_model = resolve_model(config.model)
assert llama_model is not None, f"Model {config.model} not found"
# Move weights to GPU with quantization
if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
cprint("Loading fp8 scales...", "yellow")
fp8_scales_path = os.path.join(
checkpoint.checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
assert os.path.isfile(
fp8_scales_path

View file

@ -170,7 +170,7 @@ class LlamaGuardShield(ShieldBase):
for i in range(1, len(messages)):
if messages[i].role == messages[i - 1].role:
raise ValueError(
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i-1}"
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i - 1}"
)
return messages
@ -184,7 +184,7 @@ class LlamaGuardShield(ShieldBase):
# TODO: llama-stack inference protocol has issues with non-streaming inference code
content = ""
async for chunk in self.inference_api.chat_completion(
async for chunk in await self.inference_api.chat_completion(
model=self.model,
messages=[shield_input_message],
stream=True,

View file

@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import VLLMConfig

View file

@ -134,7 +134,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
if self.engine:
self.engine.shutdown_background_loop()
def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -152,7 +152,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
logprobs=logprobs,
)
def chat_completion(
async def chat_completion(
self,
model: str,
messages: list[Message],
@ -189,7 +189,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
if stream:
return self._stream_chat_completion(request, results_generator)
else:
return self._nonstream_chat_completion(request, results_generator)
return await self._nonstream_chat_completion(request, results_generator)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
@ -207,7 +207,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(request, response, self.formatter)
return process_chat_completion_response(response, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
@ -229,7 +229,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
stream, self.formatter
):
yield chunk

View file

@ -55,11 +55,20 @@ def available_providers() -> List[ProviderSpec]:
api=Api.inference,
adapter=AdapterSpec(
adapter_type="ollama",
pip_packages=["ollama"],
pip_packages=["ollama", "aiohttp"],
config_class="llama_stack.providers.adapters.inference.ollama.OllamaImplConfig",
module="llama_stack.providers.adapters.inference.ollama",
),
),
# remote_provider_spec(
# api=Api.inference,
# adapter=AdapterSpec(
# adapter_type="vllm",
# pip_packages=["openai"],
# module="llama_stack.providers.adapters.inference.vllm",
# config_class="llama_stack.providers.adapters.inference.vllm.VLLMImplConfig",
# ),
# ),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(

View file

@ -31,4 +31,4 @@ providers:
persistence_store:
namespace: null
type: sqlite
db_path: /Users/ashwin/.llama/runtime/kvstore.db
db_path: ~/.llama/runtime/kvstore.db

View file

@ -64,6 +64,24 @@ def search_query_messages():
]
@pytest.fixture
def attachment_message():
    """Single user message announcing that Torchtune documentation is attached."""
    announcement = UserMessage(
        content="I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
    )
    return [announcement]
@pytest.fixture
def query_attachment_messages():
    """Single user message asking for the top 5 topics that were explained."""
    question = UserMessage(
        content="What are the top 5 topics that were explained? Only list succinct bullet points."
    )
    return [question]
@pytest.mark.asyncio
async def test_create_agent_turn(agents_settings, sample_messages):
agents_impl = agents_settings["impl"]
@ -98,7 +116,7 @@ async def test_create_agent_turn(agents_settings, sample_messages):
)
turn_response = [
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0
@ -123,6 +141,89 @@ async def test_create_agent_turn(agents_settings, sample_messages):
assert len(final_event.turn.output_message.content) > 0
@pytest.mark.asyncio
async def test_rag_agent_as_attachments(
    agents_settings, attachment_message, query_attachment_messages
):
    """Run two streamed agent turns: one carrying attachments, one querying them.

    An agent with a memory tool (and no pre-configured memory banks) is
    created, a session is opened, the first turn sends ``attachment_message``
    together with the tutorial attachments, and the second turn sends
    ``query_attachment_messages``. Each turn must produce at least one chunk.
    """
    # File names under the torchtune tutorials directory; each one becomes a
    # plain-text attachment referenced by its raw GitHub URL.
    tutorial_files = [
        "memory_optimizations.rst",
        "chat.rst",
        "llama3.rst",
        "datasets.rst",
        "qat_finetune.rst",
        "lora_finetune.rst",
    ]

    attachments = [
        Attachment(
            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{file_name}",
            mime_type="text/plain",
        )
        for file_name in tutorial_files
    ]

    agents_impl = agents_settings["impl"]

    agent_config = AgentConfig(
        model=agents_settings["common_params"]["model"],
        instructions=agents_settings["common_params"]["instructions"],
        enable_session_persistence=True,
        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
        input_shields=[],
        output_shields=[],
        tools=[
            MemoryToolDefinition(
                memory_bank_configs=[],
                query_generator_config={
                    "type": "default",
                    "sep": " ",
                },
                max_tokens_in_context=4096,
                max_chunks=10,
            ),
        ],
        max_infer_iters=5,
    )

    create_response = await agents_impl.create_agent(agent_config)
    agent_id = create_response.agent_id

    # Create a session
    session_create_response = await agents_impl.create_agent_session(
        agent_id, "Test Session"
    )
    session_id = session_create_response.session_id

    # Create and execute a turn
    turn_request = dict(
        agent_id=agent_id,
        session_id=session_id,
        messages=attachment_message,
        attachments=attachments,
        stream=True,
    )

    turn_response = [
        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
    ]

    assert len(turn_response) > 0

    # Create a second turn querying the agent
    turn_request = dict(
        agent_id=agent_id,
        session_id=session_id,
        messages=query_attachment_messages,
        stream=True,
    )

    turn_response = [
        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
    ]

    assert len(turn_response) > 0
@pytest.mark.asyncio
async def test_create_agent_turn_with_brave_search(
agents_settings, search_query_messages
@ -169,7 +270,7 @@ async def test_create_agent_turn_with_brave_search(
)
turn_response = [
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0

View file

@ -4,6 +4,10 @@ providers:
config:
host: localhost
port: 11434
- provider_id: meta-reference
provider_type: meta-reference
config:
model: Llama3.2-1B-Instruct
- provider_id: test-tgi
provider_type: remote::tgi
config:

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import itertools
import os
import pytest
import pytest_asyncio
@ -50,14 +51,17 @@ def get_expected_stop_reason(model: str):
return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
# Models the inference tests are parametrized over. Set the MODEL_IDS
# environment variable to a comma-separated list to override the defaults.
# Entries are stripped and blank ones dropped, so "a, b", a trailing comma or
# an empty value never produce a bogus model id; if nothing usable remains,
# the default pair is used.
_requested_model_ids = [
    model_id.strip() for model_id in os.environ.get("MODEL_IDS", "").split(",")
]
MODEL_IDS = [m for m in _requested_model_ids if m] or [Llama_8B, Llama_3B]
# This is going to create multiple Stack impls without tearing down the previous one
# Fix that!
@pytest_asyncio.fixture(
scope="session",
params=[
{"model": Llama_8B},
{"model": Llama_3B},
],
params=[{"model": m} for m in MODEL_IDS],
ids=lambda d: d["model"],
)
async def inference_settings(request):
@ -122,6 +126,48 @@ async def test_model_list(inference_settings):
assert model_def.identifier == params["model"]
@pytest.mark.asyncio
async def test_completion(inference_settings):
    """Exercise ``completion()`` in non-streaming and then streaming mode.

    The test is skipped unless the model's provider type is one of
    ``supported_providers``.
    """
    inference_impl = inference_settings["impl"]
    model = inference_settings["common_params"]["model"]

    supported_providers = ("meta-reference", "remote::ollama")
    provider = inference_impl.routing_table.get_provider_impl(model)
    if provider.__provider_spec__.provider_type not in supported_providers:
        pytest.skip("Other inference providers don't support completion() yet")

    # Non-streaming: awaiting the call gives back a single response object.
    response = await inference_impl.completion(
        content="Roses are red,",
        stream=False,
        model=model,
        sampling_params=SamplingParams(
            max_tokens=50,
        ),
    )

    assert isinstance(response, CompletionResponse)
    assert "violets are blue" in response.content

    # Streaming: awaiting the call gives back an async iterator of chunks.
    stream = await inference_impl.completion(
        content="Roses are red,",
        stream=True,
        model=model,
        sampling_params=SamplingParams(
            max_tokens=50,
        ),
    )
    chunks = []
    async for chunk in stream:
        chunks.append(chunk)

    assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
    # max_tokens is 50 and exactly 51 chunks are expected.
    assert len(chunks) == 51

    final_chunk = chunks[-1]
    assert final_chunk.stop_reason == StopReason.out_of_tokens
@pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"]
@ -142,7 +188,7 @@ async def test_chat_completion_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"]
response = [
r
async for r in inference_impl.chat_completion(
async for r in await inference_impl.chat_completion(
messages=sample_messages,
stream=True,
**inference_settings["common_params"],
@ -213,7 +259,7 @@ async def test_chat_completion_with_tool_calling_streaming(
response = [
r
async for r in inference_impl.chat_completion(
async for r in await inference_impl.chat_completion(
messages=messages,
tools=[sample_tool_definition],
stream=True,

View file

@ -2,8 +2,8 @@ providers:
- provider_id: test-faiss
provider_type: meta-reference
config: {}
- provider_id: test-chroma
provider_type: remote::chroma
- provider_id: test-chromadb
provider_type: remote::chromadb
config:
host: localhost
port: 6001

View file

@ -89,6 +89,30 @@ async def test_banks_list(memory_settings):
assert len(response) == 0
@pytest.mark.asyncio
async def test_banks_register(memory_settings):
    """Register the same memory bank twice and check that only one bank is listed."""
    # NOTE: this needs you to ensure that you are starting from a clean state
    # but so far we don't have an unregister API unfortunately, so be careful
    banks_impl = memory_settings["memory_banks_impl"]
    bank = VectorMemoryBankDef(
        identifier="test_bank_no_provider",
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=512,
        overlap_size_in_tokens=64,
    )

    # First pass registers the bank; second pass registers the same bank with
    # the same id again. After either pass exactly one bank must be listed.
    for _ in range(2):
        await banks_impl.register_memory_bank(bank)
        listed = await banks_impl.list_memory_banks()
        assert isinstance(listed, list)
        assert len(listed) == 1
@pytest.mark.asyncio
async def test_query_documents(memory_settings, sample_documents):
memory_impl = memory_settings["memory_impl"]

View file

@ -14,7 +14,7 @@ import yaml
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls_with_routing
from llama_stack.distribution.resolver import resolve_impls
async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
@ -36,7 +36,7 @@ async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
providers=chosen,
)
run_config = parse_and_maybe_upgrade_config(run_config)
impls = await resolve_impls_with_routing(run_config)
impls = await resolve_impls(run_config)
if "provider_data" in config_dict:
provider_id = chosen[api.value][0].provider_id

View file

@ -34,6 +34,8 @@ def get_sampling_options(request: ChatCompletionRequest) -> dict:
if params := request.sampling_params:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(params, attr):
if attr == "max_tokens":
options["num_predict"] = getattr(params, attr)
options[attr] = getattr(params, attr)
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
@ -49,27 +51,35 @@ def text_from_choice(choice) -> str:
return choice.text
def get_stop_reason(finish_reason: str) -> StopReason:
    """Map a finish reason string to a ``StopReason``.

    Args:
        finish_reason: the finish reason reported for a completion choice.

    Returns:
        ``StopReason.end_of_turn`` for "stop", "eos" or "eos_token",
        ``StopReason.end_of_message`` for "eom", and
        ``StopReason.out_of_tokens`` for "length" or any other value.
    """
    # "eos_token" is included so that non-streaming responses agree with
    # process_completion_stream_response, which already treats it as
    # end-of-turn; without it a naturally finished generation would be
    # reported as out_of_tokens.
    if finish_reason in ("stop", "eos", "eos_token"):
        return StopReason.end_of_turn
    if finish_reason == "eom":
        return StopReason.end_of_message
    # "length" and every unrecognised (or missing) reason share this default.
    return StopReason.out_of_tokens
def process_completion_response(
    response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> CompletionResponse:
    """Build a ``CompletionResponse`` from the first choice of ``response``.

    The choice's finish reason is translated with ``get_stop_reason`` and its
    text becomes the content. ``formatter`` is accepted but not used here.
    """
    first_choice = response.choices[0]
    stop_reason = get_stop_reason(first_choice.finish_reason)
    return CompletionResponse(stop_reason=stop_reason, content=first_choice.text)
def process_chat_completion_response(
request: ChatCompletionRequest,
response: OpenAICompatCompletionResponse,
formatter: ChatFormat,
response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> ChatCompletionResponse:
choice = response.choices[0]
stop_reason = None
if reason := choice.finish_reason:
if reason in ["stop", "eos"]:
stop_reason = StopReason.end_of_turn
elif reason == "eom":
stop_reason = StopReason.end_of_message
elif reason == "length":
stop_reason = StopReason.out_of_tokens
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
completion_message = formatter.decode_assistant_message_from_content(
text_from_choice(choice), stop_reason
text_from_choice(choice), get_stop_reason(choice.finish_reason)
)
return ChatCompletionResponse(
completion_message=completion_message,
@ -77,10 +87,45 @@ def process_chat_completion_response(
)
async def process_completion_stream_response(
    stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
) -> AsyncGenerator:
    """Convert a stream of OpenAI-compatible chunks into completion stream chunks.

    For every incoming chunk only the first choice is inspected:

    * a truthy finish reason ends consumption of ``stream`` ("stop", "eos" and
      "eos_token" record end-of-turn, "length" records out-of-tokens, anything
      else leaves the recorded stop reason untouched);
    * the literal texts ``<|eot_id|>`` / ``<|eom_id|>`` record a stop reason
      and are not emitted;
    * any other text is emitted as a delta carrying the stop reason recorded
      so far.

    A final chunk with an empty delta and the last recorded stop reason (which
    may still be ``None``) is always emitted. ``formatter`` is accepted but not
    used here.
    """
    stop_reason = None

    async for chunk in stream:
        choice = chunk.choices[0]

        reason = choice.finish_reason
        if reason:
            if reason in ["stop", "eos", "eos_token"]:
                stop_reason = StopReason.end_of_turn
            elif reason == "length":
                stop_reason = StopReason.out_of_tokens
            break

        text = text_from_choice(choice)
        if text == "<|eot_id|>":
            stop_reason = StopReason.end_of_turn
        elif text == "<|eom_id|>":
            stop_reason = StopReason.end_of_message
        else:
            yield CompletionResponseStreamChunk(
                delta=text,
                stop_reason=stop_reason,
            )

    # Closing chunk: no text, just the final stop reason.
    yield CompletionResponseStreamChunk(
        delta="",
        stop_reason=stop_reason,
    )
async def process_chat_completion_stream_response(
request: ChatCompletionRequest,
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
formatter: ChatFormat,
stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
) -> AsyncGenerator:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(

View file

@ -23,6 +23,13 @@ from llama_models.sku_list import resolve_model
from llama_stack.providers.utils.inference import supported_inference_models
def completion_request_to_prompt(
    request: CompletionRequest, formatter: ChatFormat
) -> str:
    """Render the content of a completion request as a prompt string.

    The content is encoded with ``formatter.encode_content`` and the resulting
    tokens are decoded back to text with ``formatter.tokenizer.decode``.
    """
    encoded = formatter.encode_content(request.content)
    prompt = formatter.tokenizer.decode(encoded.tokens)
    return prompt
def chat_completion_request_to_prompt(
request: ChatCompletionRequest, formatter: ChatFormat
) -> str:

View file

@ -152,7 +152,7 @@ def severity(levelname: str) -> LogSeverity:
elif levelname == "INFO":
return LogSeverity.INFO
elif levelname == "WARNING":
return LogSeverity.WARNING
return LogSeverity.WARN
elif levelname == "ERROR":
return LogSeverity.ERROR
elif levelname == "CRITICAL":