API Updates: fleshing out RAG APIs, introduce "llama stack" CLI command (#51)

* add tools to chat completion request

* use templates for generating system prompts

* Moved ToolPromptFormat and jinja templates to llama_models.llama3.api

* <WIP> memory changes

- inlined AgenticSystemInstanceConfig so API feels more ergonomic
- renamed it to AgentConfig, AgentInstance -> Agent
- added a MemoryConfig and `memory` parameter
- added `attachments` to input and `output_attachments` to the response

- some naming changes

* InterleavedTextAttachment -> InterleavedTextMedia, introduce memory tool

* flesh out memory banks API

* agentic loop has a RAG implementation

* faiss provider implementation

* memory client works

* re-work tool definitions, fix FastAPI issues, fix tool regressions

* fix agentic_system utils

* basic RAG seems to work

* small bug fixes for inline attachments

* Refactor custom tool execution utilities

* Bug fix, show memory retrieval steps in EventLogger

* No need for api_key for Remote providers

* add special unicode character ↵ to showcase newlines in model prompt templates

* remove api.endpoints imports

* combine datatypes.py and endpoints.py into api.py

* Attachment / add TTL api

* split batch_inference from inference

* minor import fixes

* use a single impl for ChatFormat.decode_assistant_message

* use interleaved_text_media_as_str() utility

* Fix api.datatypes imports

* Add blobfile for tiktoken

* Add ToolPromptFormat to ChatFormat.encode_message so that tools are encoded properly

* templates take optional --format={json,function_tag}

* Rag Updates

* Add `api build` subcommand -- WIP

* fix

* build + run image seems to work

* <WIP> adapters

* bunch more work to make adapters work

* api build works for conda now

* ollama remote adapter works

* Several smaller fixes to make adapters work

Also, reorganized the pattern of __init__ inside providers so
configuration can stay lightweight

* llama distribution -> llama stack + containers (WIP)

* All the new CLI for api + stack work

* Make Fireworks and Together into the Adapter format

* Some quick fixes to the CLI behavior to make it consistent

* Updated README phew

* Update cli_reference.md

* llama_toolchain/distribution -> llama_toolchain/core

* Add termcolor

* update paths

* Add a log just for consistency

* chmod +x scripts

* Fix api dependencies not getting added to configuration

* missing import lol

* Delete utils.py; move to agentic system

* Support downloading of URLs for attachments for code interpreter

* Simplify and generalize `llama api build` yay

* Update `llama stack configure` to be very simple also

* Fix stack start

* Allow building an "adhoc" distribution

* Remove `llama api []` subcommands

* Fixes to llama stack commands and update docs

* Update documentation again and add error messages to llama stack start

* llama stack start -> llama stack run

* Change name of build for less confusion

* Add pyopenapi fork to the repository, update RFC assets

* Remove conflicting annotation

* Added a "--raw" option for model template printing

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
Co-authored-by: Dalton Flanagan <6599399+dltn@users.noreply.github.com>
Authored by Ashwin Bharambe on 2024-09-03 22:39:39 -07:00; committed by GitHub.
parent 35093c0b6f · commit 7bc7785b0d
141 changed files with 8252 additions and 4032 deletions

@@ -3,6 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import OllamaImplConfig # noqa
from .ollama import get_provider_impl # noqa

@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import FireworksImplConfig
async def get_adapter_impl(config: FireworksImplConfig, _deps) -> Inference:
from .fireworks import FireworksInferenceAdapter
assert isinstance(
config, FireworksImplConfig
), f"Unexpected config type: {type(config)}"
impl = FireworksInferenceAdapter(config)
await impl.initialize()
return impl
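
Each adapter package now exposes a uniform async `get_adapter_impl(config, deps)` entry point. The sketch below is illustrative only: it shows how a registry could resolve such an entry point from the dotted module path recorded in an `AdapterSpec`; the helper name and the importlib-based lookup are assumptions, not code from this change.

```python
# Hypothetical helper (not part of this diff): load an adapter module by the dotted
# path stored in its AdapterSpec and await its get_adapter_impl entry point.
import importlib
from typing import Any


async def instantiate_adapter(module_path: str, config: Any, deps: Any = None) -> Any:
    # e.g. module_path = "llama_toolchain.inference.adapters.ollama"
    module = importlib.import_module(module_path)
    # Each adapter's get_adapter_impl validates the config, builds the adapter,
    # awaits initialize(), and returns a ready-to-use Inference implementation.
    return await module.get_adapter_impl(config, deps)
```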

@@ -5,9 +5,9 @@
# the root directory of this source tree.
import uuid
from typing import AsyncGenerator, Dict
from typing import AsyncGenerator
import httpx
from fireworks.client import Fireworks
from llama_models.llama3.api.datatypes import (
BuiltinTool,
@@ -18,20 +18,8 @@ from llama_models.llama3.api.datatypes (
)
from llama_models.llama3.api.tool_utils import ToolUtils
from llama_models.sku_list import resolve_model
from fireworks.client import Fireworks
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionRequest,
Inference,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_toolchain.inference.api import * # noqa: F403
from .config import FireworksImplConfig
@@ -42,18 +30,7 @@ FIREWORKS_SUPPORTED_MODELS = {
}
async def get_provider_impl(
config: FireworksImplConfig, _deps: Dict[Api, ProviderSpec]
) -> Inference:
assert isinstance(
config, FireworksImplConfig
), f"Unexpected config type: {type(config)}"
impl = FireworksInference(config)
await impl.initialize()
return impl
class FireworksInference(Inference):
class FireworksInferenceAdapter(Inference):
def __init__(self, config: FireworksImplConfig) -> None:
self.config = config

@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_toolchain.core.datatypes import RemoteProviderConfig
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
from .ollama import OllamaInferenceAdapter
impl = OllamaInferenceAdapter(config.url)
await impl.initialize()
return impl
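
With Ollama as a remote adapter, the provider-specific `OllamaImplConfig` goes away and the adapter is built from a generic `RemoteProviderConfig`. A minimal hand-wiring sketch follows; it assumes `RemoteProviderConfig` accepts a `url` field and reuses the default Ollama URL from the deleted config class.

```python
# Illustrative only; RemoteProviderConfig(url=...) and the default server URL are
# assumptions based on the surrounding diff, not code introduced by this change.
import asyncio

from llama_toolchain.core.datatypes import RemoteProviderConfig
from llama_toolchain.inference.adapters.ollama import get_adapter_impl


async def main() -> None:
    config = RemoteProviderConfig(url="http://localhost:11434")  # default Ollama server
    inference = await get_adapter_impl(config, None)  # initialized OllamaInferenceAdapter
    print(type(inference).__name__)


asyncio.run(main())
```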

@@ -4,63 +4,37 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from typing import AsyncGenerator, Dict
from typing import AsyncGenerator
import httpx
from llama_models.llama3.api.datatypes import (
BuiltinTool,
CompletionMessage,
Message,
StopReason,
ToolCall,
)
from llama_models.llama3.api.tool_utils import ToolUtils
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from ollama import AsyncClient
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionRequest,
Inference,
ToolCallDelta,
ToolCallParseStatus,
)
from .config import OllamaImplConfig
from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.inference.prepare_messages import prepare_messages
# TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models
OLLAMA_SUPPORTED_SKUS = {
# "Meta-Llama3.1-8B-Instruct": "llama3.1",
"Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
"Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
}
async def get_provider_impl(
config: OllamaImplConfig, _deps: Dict[Api, ProviderSpec]
) -> Inference:
assert isinstance(
config, OllamaImplConfig
), f"Unexpected config type: {type(config)}"
impl = OllamaInference(config)
await impl.initialize()
return impl
class OllamaInference(Inference):
def __init__(self, config: OllamaImplConfig) -> None:
self.config = config
class OllamaInferenceAdapter(Inference):
def __init__(self, url: str) -> None:
self.url = url
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@property
def client(self) -> AsyncClient:
return AsyncClient(host=self.config.url)
return AsyncClient(host=self.url)
async def initialize(self) -> None:
try:
@@ -111,6 +85,7 @@ class OllamaInference(Inference):
return options
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
messages = prepare_messages(request)
# accumulate sampling params and other options to pass to ollama
options = self.get_ollama_chat_options(request)
ollama_model = self.resolve_ollama_model(request.model)
@@ -132,7 +107,7 @@ class OllamaInference(Inference):
if not request.stream:
r = await self.client.chat(
model=ollama_model,
messages=self._messages_to_ollama_messages(request.messages),
messages=self._messages_to_ollama_messages(messages),
stream=False,
options=options,
)
@@ -143,9 +118,8 @@ class OllamaInference(Inference):
elif r["done_reason"] == "length":
stop_reason = StopReason.out_of_tokens
completion_message = decode_assistant_message_from_content(
r["message"]["content"],
stop_reason,
completion_message = self.formatter.decode_assistant_message_from_content(
r["message"]["content"], stop_reason
)
yield ChatCompletionResponse(
completion_message=completion_message,
@@ -160,7 +134,7 @@ class OllamaInference(Inference):
)
stream = await self.client.chat(
model=ollama_model,
messages=self._messages_to_ollama_messages(request.messages),
messages=self._messages_to_ollama_messages(messages),
stream=True,
options=options,
)
@@ -228,7 +202,9 @@ class OllamaInference(Inference):
)
# parse tool calls and report errors
message = decode_assistant_message_from_content(buffer, stop_reason)
message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
@@ -261,70 +237,3 @@ class OllamaInference(Inference):
stop_reason=stop_reason,
)
)
# TODO: Consolidate this with impl in llama-models
def decode_assistant_message_from_content(
content: str,
stop_reason: StopReason,
) -> CompletionMessage:
ipython = content.startswith("<|python_tag|>")
if ipython:
content = content[len("<|python_tag|>") :]
if content.endswith("<|eot_id|>"):
content = content[: -len("<|eot_id|>")]
stop_reason = StopReason.end_of_turn
elif content.endswith("<|eom_id|>"):
content = content[: -len("<|eom_id|>")]
stop_reason = StopReason.end_of_message
tool_name = None
tool_arguments = {}
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
if custom_tool_info is not None:
tool_name, tool_arguments = custom_tool_info
# Sometimes when agent has custom tools alongside builin tools
# Agent responds for builtin tool calls in the format of the custom tools
# This code tries to handle that case
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
tool_arguments = {
"query": list(tool_arguments.values())[0],
}
else:
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
if builtin_tool_info is not None:
tool_name, query = builtin_tool_info
tool_arguments = {
"query": query,
}
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
elif ipython:
tool_name = BuiltinTool.code_interpreter
tool_arguments = {
"code": content,
}
tool_calls = []
if tool_name is not None and tool_arguments is not None:
call_id = str(uuid.uuid4())
tool_calls.append(
ToolCall(
call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
)
)
content = ""
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
return CompletionMessage(
content=content,
stop_reason=stop_reason,
tool_calls=tool_calls,
)

@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import TogetherImplConfig
async def get_adapter_impl(config: TogetherImplConfig, _deps) -> Inference:
from .together import TogetherInferenceAdapter
assert isinstance(
config, TogetherImplConfig
), f"Unexpected config type: {type(config)}"
impl = TogetherInferenceAdapter(config)
await impl.initialize()
return impl

@@ -5,7 +5,7 @@
# the root directory of this source tree.
import uuid
from typing import AsyncGenerator, Dict
from typing import AsyncGenerator
from llama_models.llama3.api.datatypes import (
BuiltinTool,
@@ -18,18 +18,7 @@ from llama_models.llama3.api.tool_utils import ToolUtils
from llama_models.sku_list import resolve_model
from together import Together
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionRequest,
Inference,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_toolchain.inference.api import * # noqa: F403
from .config import TogetherImplConfig
@@ -40,18 +29,7 @@ TOGETHER_SUPPORTED_MODELS = {
}
async def get_provider_impl(
config: TogetherImplConfig, _deps: Dict[Api, ProviderSpec]
) -> Inference:
assert isinstance(
config, TogetherImplConfig
), f"Unexpected config type: {type(config)}"
impl = TogetherInference(config)
await impl.initialize()
return impl
class TogetherInference(Inference):
class TogetherInferenceAdapter(Inference):
def __init__(self, config: TogetherImplConfig) -> None:
self.config = config

@@ -4,5 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F401 F403
from .endpoints import * # noqa: F401 F403
from .api import * # noqa: F401 F403

@@ -4,17 +4,79 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datatypes import * # noqa: F403
from typing import Optional, Protocol
from enum import Enum
# this dependency is annoying and we need a forked up version anyway
from llama_models.schema_utils import webmethod
from typing import List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
class LogProbConfig(BaseModel):
top_k: Optional[int] = 0
@json_schema_type
class QuantizationType(Enum):
bf16 = "bf16"
fp8 = "fp8"
@json_schema_type
class Fp8QuantizationConfig(BaseModel):
type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
@json_schema_type
class Bf16QuantizationConfig(BaseModel):
type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
QuantizationConfig = Annotated[
Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
Field(discriminator="type"),
]
@json_schema_type
class ChatCompletionResponseEventType(Enum):
start = "start"
complete = "complete"
progress = "progress"
@json_schema_type
class ToolCallParseStatus(Enum):
started = "started"
in_progress = "in_progress"
failure = "failure"
success = "success"
@json_schema_type
class ToolCallDelta(BaseModel):
content: Union[str, ToolCall]
parse_status: ToolCallParseStatus
@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
"""Chat completion response event."""
event_type: ChatCompletionResponseEventType
delta: Union[str, ToolCallDelta]
logprobs: Optional[List[TokenLogProbs]] = None
stop_reason: Optional[StopReason] = None
@json_schema_type
class CompletionRequest(BaseModel):
model: str
content: InterleavedTextAttachment
content: InterleavedTextMedia
sampling_params: Optional[SamplingParams] = SamplingParams()
stream: Optional[bool] = False
@@ -39,7 +101,7 @@ class CompletionResponseStreamChunk(BaseModel):
@json_schema_type
class BatchCompletionRequest(BaseModel):
model: str
content_batch: List[InterleavedTextAttachment]
content_batch: List[InterleavedTextMedia]
sampling_params: Optional[SamplingParams] = SamplingParams()
logprobs: Optional[LogProbConfig] = None
@@ -56,7 +118,11 @@ class ChatCompletionRequest(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams()
# zero-shot tool definitions as input to the model
available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
@@ -82,8 +148,11 @@ class BatchChatCompletionRequest(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams()
# zero-shot tool definitions as input to the model
available_tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
logprobs: Optional[LogProbConfig] = None
@@ -92,6 +161,11 @@ class BatchChatCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
@json_schema_type
class EmbeddingsResponse(BaseModel):
embeddings: List[List[float]]
class Inference(Protocol):
@webmethod(route="/inference/completion")
async def completion(
@@ -105,14 +179,9 @@ class Inference(Protocol):
request: ChatCompletionRequest,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
@webmethod(route="/inference/batch_completion")
async def batch_completion(
@webmethod(route="/inference/embeddings")
async def embeddings(
self,
request: BatchCompletionRequest,
) -> BatchCompletionResponse: ...
@webmethod(route="/inference/batch_chat_completion")
async def batch_chat_completion(
self,
request: BatchChatCompletionRequest,
) -> BatchChatCompletionResponse: ...
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ...
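
With `datatypes.py` and `endpoints.py` merged into `api.py`, a chat completion request now carries its tool definitions and the desired prompt format directly. A rough construction sketch follows; the `UserMessage`/`ToolDefinition` field names come from `llama_models.llama3.api.datatypes` and are assumptions, since they are not shown in this hunk.

```python
# Illustrative request construction; import paths and message/tool field names are
# assumptions based on the star imports above, not verified against this commit.
from llama_models.llama3.api.datatypes import ToolDefinition, UserMessage
from llama_toolchain.inference.api import (
    ChatCompletionRequest,
    ToolChoice,
    ToolPromptFormat,
)

request = ChatCompletionRequest(
    model="Meta-Llama3.1-8B-Instruct",
    messages=[UserMessage(content="What's the weather in Menlo Park?")],
    # zero-shot tool definitions, renamed from `available_tools` to `tools`
    tools=[ToolDefinition(tool_name="get_weather", description="Look up the weather")],
    tool_choice=ToolChoice.auto,               # default
    tool_prompt_format=ToolPromptFormat.json,  # or ToolPromptFormat.function_tag
    stream=False,
)
```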

@@ -1,72 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import List, Literal, Optional, Union
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
class LogProbConfig(BaseModel):
top_k: Optional[int] = 0
@json_schema_type
class QuantizationType(Enum):
bf16 = "bf16"
fp8 = "fp8"
@json_schema_type
class Fp8QuantizationConfig(BaseModel):
type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
@json_schema_type
class Bf16QuantizationConfig(BaseModel):
type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
QuantizationConfig = Annotated[
Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
Field(discriminator="type"),
]
@json_schema_type
class ChatCompletionResponseEventType(Enum):
start = "start"
complete = "complete"
progress = "progress"
@json_schema_type
class ToolCallParseStatus(Enum):
started = "started"
in_progress = "in_progress"
failure = "failure"
success = "success"
@json_schema_type
class ToolCallDelta(BaseModel):
content: Union[str, ToolCall]
parse_status: ToolCallParseStatus
@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
"""Chat completion response event."""
event_type: ChatCompletionResponseEventType
delta: Union[str, ToolCallDelta]
logprobs: Optional[List[TokenLogProbs]] = None
stop_reason: Optional[StopReason] = None

@@ -6,12 +6,15 @@
import asyncio
import json
from typing import AsyncGenerator
from typing import Any, AsyncGenerator
import fire
import httpx
from pydantic import BaseModel
from termcolor import cprint
from llama_toolchain.core.datatypes import RemoteProviderConfig
from .api import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -23,13 +26,16 @@ from .api import (
from .event_logger import EventLogger
async def get_client_impl(base_url: str):
return InferenceClient(base_url)
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
return InferenceClient(config.url)
def encodable_dict(d: BaseModel):
return json.loads(d.json())
class InferenceClient(Inference):
def __init__(self, base_url: str):
print(f"Initializing client for {base_url}")
self.base_url = base_url
async def initialize(self) -> None:
@@ -46,7 +52,9 @@ class InferenceClient(Inference):
async with client.stream(
"POST",
f"{self.base_url}/inference/chat_completion",
data=request.json(),
json={
"request": encodable_dict(request),
},
headers={"Content-Type": "application/json"},
timeout=20,
) as response:
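
The client no longer sends the pydantic request object as the raw POST body; it nests it under a top-level `"request"` key. A rough sketch of the resulting wire format from any HTTP client; the server address and payload values are illustrative assumptions.

```python
# Minimal sketch of the new wire format; the port and field values are assumptions.
import httpx

payload = {
    "request": {  # top-level "request" wrapper now expected by the server
        "model": "Meta-Llama3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,
    }
}

response = httpx.post(
    "http://localhost:5000/inference/chat_completion",
    json=payload,
    headers={"Content-Type": "application/json"},
    timeout=20,
)
print(response.status_code)
```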

@@ -1,8 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import FireworksImplConfig # noqa
from .fireworks import get_provider_impl # noqa

@@ -5,4 +5,15 @@
# the root directory of this source tree.
from .config import MetaReferenceImplConfig # noqa
from .inference import get_provider_impl # noqa
async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
from .inference import MetaReferenceInferenceImpl
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceInferenceImpl(config)
await impl.initialize()
return impl

@@ -11,10 +11,10 @@ from llama_models.datatypes import ModelFamily
from llama_models.schema_utils import json_schema_type
from llama_models.sku_list import all_registered_models
from llama_toolchain.inference.api import QuantizationConfig
from pydantic import BaseModel, Field, field_validator
from llama_toolchain.inference.api import QuantizationConfig
@json_schema_type
class MetaReferenceImplConfig(BaseModel):

@@ -24,7 +24,7 @@ from fairscale.nn.model_parallel.initialize import (
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.sku_list import resolve_model
@@ -279,6 +279,7 @@ class Llama:
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> Generator:
if (
max_gen_len is None
@@ -288,7 +289,10 @@
max_gen_len = self.model.params.max_seq_len - 1
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(messages),
model_input=self.formatter.encode_dialog_prompt(
messages,
tool_prompt_format,
),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,

@@ -6,12 +6,11 @@
import asyncio
from typing import AsyncIterator, Dict, Union
from typing import AsyncIterator, Union
from llama_models.llama3.api.datatypes import StopReason
from llama_models.sku_list import resolve_model
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -22,23 +21,11 @@ from llama_toolchain.inference.api import (
ToolCallDelta,
ToolCallParseStatus,
)
from llama_toolchain.inference.prepare_messages import prepare_messages
from .config import MetaReferenceImplConfig
from .model_parallel import LlamaModelParallelGenerator
async def get_provider_impl(
config: MetaReferenceImplConfig, _deps: Dict[Api, ProviderSpec]
):
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceInferenceImpl(config)
await impl.initialize()
return impl
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)
@@ -67,6 +54,7 @@ class MetaReferenceInferenceImpl(Inference):
) -> AsyncIterator[
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
]:
messages = prepare_messages(request)
model = resolve_model(request.model)
if model is None:
raise RuntimeError(
@@ -98,11 +86,12 @@ class MetaReferenceInferenceImpl(Inference):
ipython = False
for token_result in self.generator.chat_completion(
messages=request.messages,
messages=messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
buffer += token_result.text
tokens.append(token_result.token)

@@ -11,7 +11,7 @@ from functools import partial
from typing import Generator, List, Optional
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
@@ -27,6 +27,7 @@ class InferenceArgs:
top_p: float
max_gen_len: int
logprobs: bool
tool_prompt_format: ToolPromptFormat
class ModelRunner:
@@ -41,6 +42,7 @@ class ModelRunner:
task.top_p,
task.max_gen_len,
task.logprobs,
task.tool_prompt_format,
)
@@ -93,6 +95,7 @@ class LlamaModelParallelGenerator:
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> Generator:
req_obj = InferenceArgs(
messages=deepcopy(messages),
@@ -100,6 +103,7 @@
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
tool_prompt_format=tool_prompt_format,
)
gen = self.group.run_inference(req_obj)

@@ -1,16 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class OllamaImplConfig(BaseModel):
url: str = Field(
default="http://localhost:11434",
description="The URL for the ollama server",
)

@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
from llama_models.llama3.prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
JsonCustomToolGenerator,
SystemDefaultGenerator,
)
def prepare_messages(request: ChatCompletionRequest) -> List[Message]:
assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported"
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert (
existing_messages[0].role != Role.system.value
), "Should only have 1 system message"
messages = []
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content = ""
tool_template = None
if request.tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(request.tools)
sys_content += tool_template.render()
sys_content += "\n"
sys_content += default_template.render()
if existing_system_message:
# TODO: this fn is needed in many places
def _process(c):
if isinstance(c, str):
return c
else:
return "<media>"
sys_content += "\n"
if isinstance(existing_system_message.content, str):
sys_content += _process(existing_system_message.content)
elif isinstance(existing_system_message.content, list):
sys_content += "\n".join(
[_process(c) for c in existing_system_message.content]
)
messages.append(SystemMessage(content=sys_content))
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
if request.tool_prompt_format == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif request.tool_prompt_format == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator()
else:
raise ValueError(
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
)
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
custom_template = tool_gen.gen(custom_tools)
messages.append(UserMessage(content=custom_template.render()))
# Add back existing messages from the request
messages += existing_messages
return messages
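
A small usage sketch for the new helper (illustrative; the `ToolDefinition`/`UserMessage` constructor details are assumptions based on `llama_models.llama3.api.datatypes`):

```python
# Illustrative only; constructor field names are assumptions, not shown in this file.
from llama_models.llama3.api.datatypes import ToolDefinition, UserMessage
from llama_toolchain.inference.api import ChatCompletionRequest, ToolPromptFormat
from llama_toolchain.inference.prepare_messages import prepare_messages

request = ChatCompletionRequest(
    model="Meta-Llama3.1-8B-Instruct",
    messages=[UserMessage(content="Find recent papers on RAG")],
    tools=[ToolDefinition(tool_name="paper_search", description="Search a paper index")],
    tool_prompt_format=ToolPromptFormat.json,
)

messages = prepare_messages(request)
# messages[0]: SystemMessage with the rendered builtin-tool and default system templates
# messages[1]: UserMessage carrying the JSON-format custom tool definitions
# messages[2]: the original user message from the request
```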

@@ -6,7 +6,7 @@
from typing import List
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
from llama_toolchain.core.datatypes import * # noqa: F403
def available_inference_providers() -> List[ProviderSpec]:
@@ -27,14 +27,13 @@ def available_inference_providers() -> List[ProviderSpec]:
module="llama_toolchain.inference.meta_reference",
config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
),
InlineProviderSpec(
remote_provider_spec(
api=Api.inference,
provider_id="meta-ollama",
pip_packages=[
"ollama",
],
module="llama_toolchain.inference.ollama",
config_class="llama_toolchain.inference.ollama.OllamaImplConfig",
adapter=AdapterSpec(
adapter_id="ollama",
pip_packages=["ollama"],
module="llama_toolchain.inference.adapters.ollama",
),
),
InlineProviderSpec(
api=Api.inference,

@@ -14,12 +14,12 @@ import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.llama3.api.model import Transformer, TransformerBlock
from llama_toolchain.inference.api import QuantizationType
from llama_toolchain.inference.api.config import (
CheckpointQuantizationFormat,
MetaReferenceImplConfig,
)
from llama_toolchain.inference.api.datatypes import QuantizationType
from termcolor import cprint
from torch import Tensor

@@ -1,8 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import TogetherImplConfig # noqa
from .together import get_provider_impl # noqa