Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-16 09:58:10 +00:00)
Introduce Llama stack distributions (#22)
* Add distribution CLI scaffolding
* More progress towards `llama distribution install`
* getting closer to a distro definition, distro install + configure works
* Distribution server now functioning
* read existing configuration, save enums properly
* Remove inference uvicorn server entrypoint and llama inference CLI command
* updated dependency and client model name
* Improved exception handling
* local imports for faster cli
* undo a typo, add a passthrough distribution
* implement full-passthrough in the server
* add safety adapters, configuration handling, server + clients
* cleanup, moving stuff to common, nuke utils
* Add a Path() wrapper at the earliest place
* fixes
* Bring agentic system api to toolchain; add adapter dependencies and resolve adapters using a topological sort
* refactor to reduce size of `agentic_system`
* move straggler files and fix some important existing bugs
* ApiSurface -> Api
* refactor a method out
* Adapter -> Provider
* Make each inference provider into its own subdirectory
* installation fixes
* Rename Distribution -> DistributionSpec, simplify RemoteProviders
* dict key instead of attr
* update inference config to take model and not model_dir
* Fix passthrough streaming, send headers properly not as part of body :facepalm
* update safety to use model sku ids and not model dirs
* Update cli_reference.md
* minor fixes
* add DistributionConfig, fix a bug in model download
* Make install + start scripts do proper configuration automatically
* Update CLI_reference
* Nuke fp8_requirements, fold fbgemm into common requirements
* Update README, add newline between API surface configurations
* Refactor download functionality out of the Command so it can be reused
* Add `llama model download` alias for `llama download`
* Show message about checksum file so users can check themselves
* Simpler intro statements
* get ollama working
* Reduce a bunch of dependencies from toolchain; some improvements to the distribution install script
* Avoid using `conda run` since it buffers everything
* update dependencies and rely on LLAMA_TOOLCHAIN_DIR for dev purposes
* add validation for configuration input
* resort imports
* make optional subclasses default to yes for configuration
* Remove additional_pip_packages; move deps to providers
* for inline make 8b model the default
* Add scripts to MANIFEST
* allow installing from test.pypi.org
* Fix #2 to help with testing packages
* Must install llama-models at that same version first
* fix PIP_ARGS

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Co-authored-by: Hardik Shah <hjshah@meta.com>
parent da4645a27a
commit e830814399
115 changed files with 5839 additions and 1120 deletions
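For orientation before the diff itself (this sketch is not part of the commit), the change replaces the old single InferenceConfig with per-provider specs that name a module and a config class. A minimal, hypothetical loader that consumes such a spec might look like the following; the real resolution code lives in llama_toolchain.distribution and is not shown on this page:

import importlib

# Hypothetical helper: each ProviderSpec carries a `module` whose
# get_provider_impl(config, deps) coroutine builds the implementation
# (the new meta_reference and ollama packages below export exactly that).
async def instantiate_provider(spec, config, deps):
    module = importlib.import_module(spec.module)
    return await module.get_provider_impl(config, deps)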
@@ -1,102 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import Literal, Optional, Union

from hydra.core.config_store import ConfigStore

from hydra_zen import builds
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat

from pydantic import BaseModel, Field
from strong_typing.schema import json_schema_type
from typing_extensions import Annotated

from .datatypes import QuantizationConfig


@json_schema_type
class ImplType(Enum):
    inline = "inline"
    remote = "remote"
    ollama = "ollama"


@json_schema_type
class CheckpointType(Enum):
    pytorch = "pytorch"
    huggingface = "huggingface"


@json_schema_type
class PytorchCheckpoint(BaseModel):
    checkpoint_type: Literal[CheckpointType.pytorch.value] = (
        CheckpointType.pytorch.value
    )
    checkpoint_dir: str
    tokenizer_path: str
    model_parallel_size: int
    quantization_format: CheckpointQuantizationFormat = (
        CheckpointQuantizationFormat.bf16
    )


@json_schema_type
class HuggingFaceCheckpoint(BaseModel):
    checkpoint_type: Literal[CheckpointType.huggingface.value] = (
        CheckpointType.huggingface.value
    )
    repo_id: str  # or model_name ?
    model_parallel_size: int
    quantization_format: CheckpointQuantizationFormat = (
        CheckpointQuantizationFormat.bf16
    )


@json_schema_type
class ModelCheckpointConfig(BaseModel):
    checkpoint: Annotated[
        Union[PytorchCheckpoint, HuggingFaceCheckpoint],
        Field(discriminator="checkpoint_type"),
    ]


@json_schema_type
class InlineImplConfig(BaseModel):
    impl_type: Literal[ImplType.inline.value] = ImplType.inline.value
    checkpoint_config: ModelCheckpointConfig
    quantization: Optional[QuantizationConfig] = None
    torch_seed: Optional[int] = None
    max_seq_len: int
    max_batch_size: int = 1


@json_schema_type
class RemoteImplConfig(BaseModel):
    impl_type: Literal[ImplType.remote.value] = ImplType.remote.value
    url: str = Field(..., description="The URL of the remote module")


@json_schema_type
class OllamaImplConfig(BaseModel):
    impl_type: Literal[ImplType.ollama.value] = ImplType.ollama.value
    model: str = Field(..., description="The name of the model in ollama catalog")
    url: str = Field(..., description="The URL for the ollama server")


@json_schema_type
class InferenceConfig(BaseModel):
    impl_config: Annotated[
        Union[InlineImplConfig, RemoteImplConfig, OllamaImplConfig],
        Field(discriminator="impl_type"),
    ]


InferenceHydraConfig = builds(InferenceConfig)

cs = ConfigStore.instance()
cs.store(name="inference_config", node=InferenceHydraConfig)
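For contrast (not part of this diff), a minimal sketch of what the removed discriminated-union config expressed; after this commit each provider carries its own config class instead of selecting an impl_type at runtime:

# Old style: one InferenceConfig whose impl_config discriminates on impl_type.
old = InferenceConfig(
    impl_config=OllamaImplConfig(
        model="llama3.1:8b-instruct-fp16",  # "name of the model in ollama catalog"
        url="http://localhost:11434",
    )
)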
@@ -7,9 +7,9 @@
from enum import Enum
from typing import List, Literal, Optional, Union

from pydantic import BaseModel, Field
from llama_models.schema_utils import json_schema_type

from strong_typing.schema import json_schema_type
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from llama_models.llama3_1.api.datatypes import *  # noqa: F403


@@ -8,7 +8,7 @@ from .datatypes import *  # noqa: F403
from typing import Optional, Protocol

# this dependency is annoying and we need a forked up version anyway
from pyopenapi import webmethod
from llama_models.schema_utils import webmethod


@json_schema_type


@@ -1,22 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .api.config import ImplType, InferenceConfig


async def get_inference_api_instance(config: InferenceConfig):
    if config.impl_config.impl_type == ImplType.inline.value:
        from .inference import InferenceImpl

        return InferenceImpl(config.impl_config)
    elif config.impl_config.impl_type == ImplType.ollama.value:
        from .ollama import OllamaInference

        return OllamaInference(config.impl_config)

    from .client import InferenceClient

    return InferenceClient(config.impl_config.url)
@@ -23,6 +23,10 @@ from .api import (
from .event_logger import EventLogger


async def get_client_impl(base_url: str):
    return InferenceClient(base_url)


class InferenceClient(Inference):
    def __init__(self, base_url: str):
        print(f"Initializing client for {base_url}")


@@ -46,12 +50,25 @@ class InferenceClient(Inference):
            headers={"Content-Type": "application/json"},
            timeout=20,
        ) as response:
            if response.status_code != 200:
                content = await response.aread()
                cprint(
                    f"Error: HTTP {response.status_code} {content.decode()}", "red"
                )
                return

            async for line in response.aiter_lines():
                if line.startswith("data:"):
                    data = line[len("data: ") :]
                    try:
                        if request.stream:
                            yield ChatCompletionResponseStreamChunk(**json.loads(data))
                            if "error" in data:
                                cprint(data, "red")
                                continue

                            yield ChatCompletionResponseStreamChunk(
                                **json.loads(data)
                            )
                        else:
                            yield ChatCompletionResponse(**json.loads(data))
                    except Exception as e:


@@ -62,11 +79,11 @@ class InferenceClient(Inference):
async def run_main(host: str, port: int, stream: bool):
    client = InferenceClient(f"http://{host}:{port}")

    message = UserMessage(content="hello world, help me out here")
    message = UserMessage(content="hello world, troll me in two-paragraphs about 42")
    cprint(f"User>{message.content}", "green")
    iterator = client.chat_completion(
        ChatCompletionRequest(
            model="Meta-Llama-3.1-8B-Instruct",
            model="Meta-Llama3.1-8B-Instruct",
            messages=[message],
            stream=stream,
        )
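For illustration (not part of this diff), a minimal sketch of driving the updated client; the host and port are placeholders and the model descriptor is the new SKU-style name used above:

import asyncio

async def demo():
    client = InferenceClient("http://localhost:5000")
    request = ChatCompletionRequest(
        model="Meta-Llama3.1-8B-Instruct",
        messages=[UserMessage(content="hello world")],
        stream=True,
    )
    # Each streamed SSE line is decoded into a ChatCompletionResponseStreamChunk.
    async for chunk in client.chat_completion(request):
        print(chunk)

asyncio.run(demo())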
@@ -1,161 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator

from llama_models.llama3_1.api.datatypes import StopReason

from .api.config import InlineImplConfig
from .api.datatypes import (
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ToolCallDelta,
    ToolCallParseStatus,
)
from .api.endpoints import (
    ChatCompletionResponse,
    ChatCompletionRequest,
    ChatCompletionResponseStreamChunk,
    CompletionRequest,
    Inference,
)
from .model_parallel import LlamaModelParallelGenerator


class InferenceImpl(Inference):

    def __init__(self, config: InlineImplConfig) -> None:
        self.config = config

    async def initialize(self) -> None:
        self.generator = LlamaModelParallelGenerator(self.config)
        self.generator.start()

    async def shutdown(self) -> None:
        self.generator.stop()

    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
        raise NotImplementedError()

    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        if request.stream:
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.start,
                    delta="",
                )
            )

        tokens = []
        logprobs = []

        stop_reason = None

        buffer = ""
        ipython = False

        for token_result in self.generator.chat_completion(
            messages=request.messages,
            temperature=request.sampling_params.temperature,
            top_p=request.sampling_params.top_p,
            max_gen_len=request.sampling_params.max_tokens,
            logprobs=request.logprobs,
        ):
            buffer += token_result.text
            tokens.append(token_result.token)

            if not ipython and buffer.startswith("<|python_tag|>"):
                ipython = True
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            content="",
                            parse_status=ToolCallParseStatus.started,
                        ),
                    )
                )
                buffer = buffer[len("<|python_tag|>") :]
                continue

            if not request.stream:
                if request.logprobs:
                    logprobs.append(token_result.logprob)

                continue

            if token_result.text == "<|eot_id|>":
                stop_reason = StopReason.end_of_turn
                text = ""
            elif token_result.text == "<|eom_id|>":
                stop_reason = StopReason.end_of_message
                text = ""
            else:
                text = token_result.text

            if ipython:
                delta = ToolCallDelta(
                    content=text,
                    parse_status=ToolCallParseStatus.in_progress,
                )
            else:
                delta = text

            if stop_reason is None:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=delta,
                        stop_reason=stop_reason,
                    )
                )

        if stop_reason is None:
            stop_reason = StopReason.out_of_tokens

        # TODO(ashwin): parse tool calls separately here and report errors?
        # if someone breaks the iteration before coming here we are toast
        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
        if request.stream:
            parsed_tool_calls = len(message.tool_calls) > 0
            if ipython and not parsed_tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            content="",
                            parse_status=ToolCallParseStatus.failure,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            for tool_call in message.tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            content=tool_call,
                            parse_status=ToolCallParseStatus.success,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.complete,
                    delta="",
                    stop_reason=stop_reason,
                )
            )

            # TODO(ashwin): what else do we need to send out here when everything finishes?
        else:
            yield ChatCompletionResponse(
                completion_message=message,
                logprobs=logprobs if request.logprobs else None,
            )
llama_toolchain/inference/meta_reference/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import MetaReferenceImplConfig  # noqa
from .inference import get_provider_impl  # noqa
llama_toolchain/inference/meta_reference/config.py (new file, 43 lines)
@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from llama_models.datatypes import ModelFamily

from llama_models.schema_utils import json_schema_type
from llama_models.sku_list import all_registered_models

from llama_toolchain.inference.api import QuantizationConfig

from pydantic import BaseModel, Field, validator


@json_schema_type
class MetaReferenceImplConfig(BaseModel):
    model: str = Field(
        default="Meta-Llama3.1-8B-Instruct",
        description="Model descriptor from `llama model list`",
    )
    quantization: Optional[QuantizationConfig] = None
    torch_seed: Optional[int] = None
    max_seq_len: int
    max_batch_size: int = 1

    @validator("model")
    @classmethod
    def validate_model(cls, model: str) -> str:
        permitted_models = [
            m.descriptor()
            for m in all_registered_models()
            if m.model_family == ModelFamily.llama3_1
        ]
        if model not in permitted_models:
            model_list = "\n\t".join(permitted_models)
            raise ValueError(
                f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
            )
        return model
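As an illustrative aside (not part of this diff), a minimal sketch of how this validator behaves, assuming pydantic v1-style validators as imported above:

from pydantic import ValidationError

# Accepted: a descriptor reported by `llama model list`.
ok = MetaReferenceImplConfig(model="Meta-Llama3.1-8B-Instruct", max_seq_len=4096)

# Rejected: anything outside the llama3_1 family fails validation at construction time.
try:
    MetaReferenceImplConfig(model="not-a-real-model", max_seq_len=4096)
except ValidationError as e:
    print(e)  # includes the "Unknown model" message with the permitted list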
@@ -25,12 +25,27 @@ from fairscale.nn.model_parallel.initialize import (
from llama_models.llama3_1.api.args import ModelArgs
from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.model import Transformer
from llama_models.llama3_1.api.tokenizer import Tokenizer
from llama_models.llama3_1.reference_impl.model import Transformer
from llama_models.sku_list import resolve_model
from termcolor import cprint

from .api.config import CheckpointType, InlineImplConfig
from .api.datatypes import QuantizationType
from llama_toolchain.common.model_utils import model_local_dir
from llama_toolchain.inference.api import QuantizationType

from .config import MetaReferenceImplConfig


def model_checkpoint_dir(model) -> str:
    checkpoint_dir = Path(model_local_dir(model))
    if not Path(checkpoint_dir / "consolidated.00.pth").exists():
        checkpoint_dir = checkpoint_dir / "original"

    assert checkpoint_dir.exists(), (
        f"Could not find checkpoint dir: {checkpoint_dir}."
        f"Please download model using `llama download {model.descriptor()}`"
    )
    return str(checkpoint_dir)


@dataclass


@@ -42,7 +57,7 @@ class TokenResult:

class Llama:
    @staticmethod
    def build(config: InlineImplConfig):
    def build(config: MetaReferenceImplConfig):
        """
        Build a Llama instance by initializing and loading a model checkpoint.


@@ -50,9 +65,7 @@ class Llama:
        This method initializes the distributed process group, sets the device to CUDA,
        and loads the pre-trained model and tokenizer.
        """
        checkpoint = config.checkpoint_config.checkpoint
        if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
            raise NotImplementedError("HuggingFace checkpoints not supported yet")
        model = resolve_model(config.model)

        if (
            config.quantization


@@ -66,7 +79,7 @@ class Llama:
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group("nccl")

        model_parallel_size = checkpoint.model_parallel_size
        model_parallel_size = model.hardware_requirements.gpu_count
        if not model_parallel_is_initialized():
            initialize_model_parallel(model_parallel_size)


@@ -81,7 +94,8 @@ class Llama:
            sys.stdout = open(os.devnull, "w")

        start_time = time.time()
        ckpt_dir = checkpoint.checkpoint_dir
        ckpt_dir = model_checkpoint_dir(model)

        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
        assert model_parallel_size == len(


@@ -102,7 +116,9 @@ class Llama:
            max_batch_size=config.max_batch_size,
            **params,
        )
        tokenizer = Tokenizer(model_path=checkpoint.tokenizer_path)

        tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")
        tokenizer = Tokenizer(model_path=tokenizer_path)

        assert (
            model_args.vocab_size == tokenizer.n_words
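For illustration only (not part of the diff), a minimal sketch of how the new checkpoint-directory resolution might be exercised; the descriptor below is the default used by MetaReferenceImplConfig:

from llama_models.sku_list import resolve_model

# resolve_model and model_checkpoint_dir are the functions shown above.
model = resolve_model("Meta-Llama3.1-8B-Instruct")
ckpt_dir = model_checkpoint_dir(model)
# Falls back to <local model dir>/original when consolidated.00.pth is not at the
# top level, and otherwise fails the assert with a `llama download` hint.
print(ckpt_dir)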
llama_toolchain/inference/meta_reference/inference.py (new file, 204 lines)
@@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio

from typing import AsyncIterator, Dict, Union

from llama_models.llama3_1.api.datatypes import StopReason
from llama_models.sku_list import resolve_model

from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    Inference,
    ToolCallDelta,
    ToolCallParseStatus,
)

from .config import MetaReferenceImplConfig
from .model_parallel import LlamaModelParallelGenerator


async def get_provider_impl(
    config: MetaReferenceImplConfig, _deps: Dict[Api, ProviderSpec]
):
    assert isinstance(
        config, MetaReferenceImplConfig
    ), f"Unexpected config type: {type(config)}"

    impl = MetaReferenceInferenceImpl(config)
    await impl.initialize()
    return impl


# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)


class MetaReferenceInferenceImpl(Inference):

    def __init__(self, config: MetaReferenceImplConfig) -> None:
        self.config = config
        model = resolve_model(config.model)
        if model is None:
            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
        self.model = model
        # verify that the checkpoint actually is for this model lol

    async def initialize(self) -> None:
        self.generator = LlamaModelParallelGenerator(self.config)
        self.generator.start()

    async def shutdown(self) -> None:
        self.generator.stop()

    # hm, when stream=False, we should not be doing SSE :/ which is what the
    # top-level server is going to do. make the typing more specific here
    async def chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncIterator[
        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
    ]:
        model = resolve_model(request.model)
        if model is None:
            raise RuntimeError(
                f"Unknown model: {request.model}, Run `llama model list`"
            )
        elif model.descriptor() != self.model.descriptor():
            raise RuntimeError(
                f"Model mismatch: {request.model} != {self.model.descriptor()}"
            )

        if SEMAPHORE.locked():
            raise RuntimeError("Only one concurrent request is supported")

        async with SEMAPHORE:
            if request.stream:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.start,
                        delta="",
                    )
                )

            tokens = []
            logprobs = []

            stop_reason = None

            buffer = ""
            ipython = False

            for token_result in self.generator.chat_completion(
                messages=request.messages,
                temperature=request.sampling_params.temperature,
                top_p=request.sampling_params.top_p,
                max_gen_len=request.sampling_params.max_tokens,
                logprobs=request.logprobs,
            ):
                buffer += token_result.text
                tokens.append(token_result.token)

                if not ipython and buffer.startswith("<|python_tag|>"):
                    ipython = True
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=ToolCallDelta(
                                content="",
                                parse_status=ToolCallParseStatus.started,
                            ),
                        )
                    )
                    buffer = buffer[len("<|python_tag|>") :]
                    continue

                if not request.stream:
                    if request.logprobs:
                        logprobs.append(token_result.logprob)

                    continue

                if token_result.text == "<|eot_id|>":
                    stop_reason = StopReason.end_of_turn
                    text = ""
                elif token_result.text == "<|eom_id|>":
                    stop_reason = StopReason.end_of_message
                    text = ""
                else:
                    text = token_result.text

                if ipython:
                    delta = ToolCallDelta(
                        content=text,
                        parse_status=ToolCallParseStatus.in_progress,
                    )
                else:
                    delta = text

                if stop_reason is None:
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=delta,
                            stop_reason=stop_reason,
                        )
                    )

            if stop_reason is None:
                stop_reason = StopReason.out_of_tokens

            # TODO(ashwin): parse tool calls separately here and report errors?
            # if someone breaks the iteration before coming here we are toast
            message = self.generator.formatter.decode_assistant_message(
                tokens, stop_reason
            )
            if request.stream:
                parsed_tool_calls = len(message.tool_calls) > 0
                if ipython and not parsed_tool_calls:
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=ToolCallDelta(
                                content="",
                                parse_status=ToolCallParseStatus.failure,
                            ),
                            stop_reason=stop_reason,
                        )
                    )

                for tool_call in message.tool_calls:
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=ToolCallDelta(
                                content=tool_call,
                                parse_status=ToolCallParseStatus.success,
                            ),
                            stop_reason=stop_reason,
                        )
                    )

                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.complete,
                        delta="",
                        stop_reason=stop_reason,
                    )
                )

                # TODO(ashwin): what else do we need to send out here when everything finishes?
            else:
                yield ChatCompletionResponse(
                    completion_message=message,
                    logprobs=logprobs if request.logprobs else None,
                )
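As a usage illustration (not part of this diff), a minimal sketch of standing up the meta-reference provider directly; UserMessage comes from the inference API datatypes as in the client example earlier, and max_seq_len is the one required config field without a default:

import asyncio

async def demo():
    config = MetaReferenceImplConfig(max_seq_len=4096)  # model defaults to Meta-Llama3.1-8B-Instruct
    impl = await get_provider_impl(config, {})
    request = ChatCompletionRequest(
        model=config.model,
        messages=[UserMessage(content="hello world")],
        stream=True,
    )
    # Only one request may be in flight at a time; a concurrent second call
    # would hit the module-level SEMAPHORE guard and raise.
    async for chunk in impl.chat_completion(request):
        print(chunk.event.event_type)
    await impl.shutdown()

asyncio.run(demo())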
@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from copy import deepcopy
from dataclasses import dataclass
from functools import partial


@@ -12,9 +13,10 @@ from typing import Generator, List, Optional
from llama_models.llama3_1.api.chat_format import ChatFormat
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model

from .api.config import InlineImplConfig
from .generation import Llama
from .config import MetaReferenceImplConfig
from .generation import Llama, model_checkpoint_dir
from .parallel_utils import ModelParallelProcessGroup


@@ -42,7 +44,7 @@ class ModelRunner:
    )


def init_model_cb(config: InlineImplConfig):
def init_model_cb(config: MetaReferenceImplConfig):
    llama = Llama.build(config)
    return ModelRunner(llama)


@@ -58,13 +60,14 @@ class LlamaModelParallelGenerator:
    clear at the callsite why we need to use a context manager.
    """

    def __init__(self, config: InlineImplConfig):
    def __init__(self, config: MetaReferenceImplConfig):
        self.config = config

        self.model = resolve_model(self.config.model)
        # this is a hack because Agent's loop uses this to tokenize and check if input is too long
        # while the tool-use loop is going
        checkpoint = self.config.checkpoint_config.checkpoint
        self.formatter = ChatFormat(Tokenizer(checkpoint.tokenizer_path))
        checkpoint_dir = model_checkpoint_dir(self.model)
        tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
        self.formatter = ChatFormat(Tokenizer(tokenizer_path))

    def start(self):
        self.__enter__()


@@ -73,9 +76,8 @@ class LlamaModelParallelGenerator:
        self.__exit__(None, None, None)

    def __enter__(self):
        checkpoint = self.config.checkpoint_config.checkpoint
        self.group = ModelParallelProcessGroup(
            checkpoint.model_parallel_size,
            self.model.hardware_requirements.gpu_count,
            init_model_cb=partial(init_model_cb, self.config),
        )
        self.group.start()
llama_toolchain/inference/ollama/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import OllamaImplConfig  # noqa
from .ollama import get_provider_impl  # noqa
llama_toolchain/inference/ollama/config.py (new file, 16 lines)
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class OllamaImplConfig(BaseModel):
    url: str = Field(
        default="http://localhost:11434",
        description="The URL for the ollama server",
    )
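For illustration (not part of this diff), a minimal sketch: the config now carries only the server URL, so a default instance points at a local Ollama daemon:

config = OllamaImplConfig()  # url defaults to http://localhost:11434
remote = OllamaImplConfig(url="http://10.0.0.5:11434")  # hypothetical remote host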
@@ -1,11 +1,14 @@
import httpx
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid
from typing import AsyncGenerator, Dict

from typing import AsyncGenerator
import httpx

from ollama import AsyncClient

from llama_models.sku_list import resolve_model
from llama_models.llama3_1.api.datatypes import (
    BuiltinTool,
    CompletionMessage,


@@ -14,44 +17,56 @@ from llama_models.llama3_1.api.datatypes import (
    ToolCall,
)
from llama_models.llama3_1.api.tool_utils import ToolUtils

from .api.config import OllamaImplConfig
from .api.datatypes import (
from llama_models.sku_list import resolve_model
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ToolCallDelta,
    ToolCallParseStatus,
)
from .api.endpoints import (
    ChatCompletionResponse,
    ChatCompletionRequest,
    ChatCompletionResponseStreamChunk,
    CompletionRequest,
    Inference,
    ToolCallDelta,
    ToolCallParseStatus,
)
from ollama import AsyncClient

from .config import OllamaImplConfig

# TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models
OLLAMA_SUPPORTED_SKUS = {
    "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16"
    # TODO: Add other variants for llama3.1
    "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
    "Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
}


async def get_provider_impl(
    config: OllamaImplConfig, _deps: Dict[Api, ProviderSpec]
) -> Inference:
    assert isinstance(
        config, OllamaImplConfig
    ), f"Unexpected config type: {type(config)}"
    impl = OllamaInference(config)
    await impl.initialize()
    return impl


class OllamaInference(Inference):

    def __init__(self, config: OllamaImplConfig) -> None:
        self.config = config
        self.model = config.model

    @property
    def client(self) -> AsyncClient:
        return AsyncClient(host=self.config.url)

    async def initialize(self) -> None:
        self.client = AsyncClient(host=self.config.url)
        try:
            status = await self.client.pull(self.model)
            assert status['status'] == 'success', f"Failed to pull model {self.model} in ollama"
            await self.client.ps()
        except httpx.ConnectError:
            print("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
            raise
            raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal")

    async def shutdown(self) -> None:
        pass


@@ -62,17 +77,19 @@ class OllamaInference(Inference):
    def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
        ollama_messages = []
        for message in messages:
            ollama_messages.append(
                {"role": message.role, "content": message.content}
            )
            if message.role == "ipython":
                role = "tool"
            else:
                role = message.role
            ollama_messages.append({"role": role, "content": message.content})

        return ollama_messages

    def resolve_ollama_model(self, model_name: str) -> str:
        model = resolve_model(model_name)
        assert (
            model is not None and
            model.descriptor(shorten_default_variant=True) in OLLAMA_SUPPORTED_SKUS
            model is not None
            and model.descriptor(shorten_default_variant=True) in OLLAMA_SUPPORTED_SKUS
        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(OLLAMA_SUPPORTED_SKUS.keys())}"

        return OLLAMA_SUPPORTED_SKUS.get(model.descriptor(shorten_default_variant=True))


@@ -84,8 +101,8 @@ class OllamaInference(Inference):
            if getattr(request.sampling_params, attr):
                options[attr] = getattr(request.sampling_params, attr)
        if (
            request.sampling_params.repetition_penalty is not None and
            request.sampling_params.repetition_penalty != 1.0
            request.sampling_params.repetition_penalty is not None
            and request.sampling_params.repetition_penalty != 1.0
        ):
            options["repeat_penalty"] = request.sampling_params.repetition_penalty


@@ -95,6 +112,21 @@ class OllamaInference(Inference):
        # accumulate sampling params and other options to pass to ollama
        options = self.get_ollama_chat_options(request)
        ollama_model = self.resolve_ollama_model(request.model)

        res = await self.client.ps()
        need_model_pull = True
        for r in res["models"]:
            if ollama_model == r["model"]:
                need_model_pull = False
                break

        if need_model_pull:
            print(f"Pulling model: {ollama_model}")
            status = await self.client.pull(ollama_model)
            assert (
                status["status"] == "success"
            ), f"Failed to pull model {self.model} in ollama"

        if not request.stream:
            r = await self.client.chat(
                model=ollama_model,


@@ -103,14 +135,14 @@ class OllamaInference(Inference):
                options=options,
            )
            stop_reason = None
            if r['done']:
                if r['done_reason'] == 'stop':
            if r["done"]:
                if r["done_reason"] == "stop":
                    stop_reason = StopReason.end_of_turn
                elif r['done_reason'] == 'length':
                elif r["done_reason"] == "length":
                    stop_reason = StopReason.out_of_tokens

            completion_message = decode_assistant_message_from_content(
                r['message']['content'],
                r["message"]["content"],
                stop_reason,
            )
            yield ChatCompletionResponse(


@@ -124,7 +156,6 @@ class OllamaInference(Inference):
                    delta="",
                )
            )

            stream = await self.client.chat(
                model=ollama_model,
                messages=self._messages_to_ollama_messages(request.messages),


@@ -137,15 +168,14 @@ class OllamaInference(Inference):
            stop_reason = None

            async for chunk in stream:
                # check if ollama is done
                if chunk['done']:
                    if chunk['done_reason'] == 'stop':
                if chunk["done"]:
                    if stop_reason is None and chunk["done_reason"] == "stop":
                        stop_reason = StopReason.end_of_turn
                    elif chunk['done_reason'] == 'length':
                    elif stop_reason is None and chunk["done_reason"] == "length":
                        stop_reason = StopReason.out_of_tokens
                    break

                text = chunk['message']['content']
                text = chunk["message"]["content"]

                # check if its a tool call ( aka starts with <|python_tag|> )
                if not ipython and text.startswith("<|python_tag|>"):


@@ -159,7 +189,7 @@ class OllamaInference(Inference):
                        ),
                    )
                )
                buffer = buffer[len("<|python_tag|>") :]
                buffer += text
                continue

                if ipython:


@@ -197,7 +227,6 @@ class OllamaInference(Inference):

            # parse tool calls and report errors
            message = decode_assistant_message_from_content(buffer, stop_reason)

            parsed_tool_calls = len(message.tool_calls) > 0
            if ipython and not parsed_tool_calls:
                yield ChatCompletionResponseStreamChunk(


@@ -232,7 +261,7 @@ class OllamaInference(Inference):
    )


#TODO: Consolidate this with impl in llama-models
# TODO: Consolidate this with impl in llama-models
def decode_assistant_message_from_content(
    content: str,
    stop_reason: StopReason,
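For illustration only (not part of this diff), a minimal sketch of the SKU-to-Ollama mapping above; it assumes the provider is built from a default OllamaImplConfig as shown earlier:

impl = OllamaInference(OllamaImplConfig())
print(impl.resolve_ollama_model("Meta-Llama3.1-70B-Instruct"))
# -> "llama3.1:70b-instruct-fp16"; an unsupported SKU fails the assert and
#    reports the list of supported models instead.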
llama_toolchain/inference/providers.py (new file, 39 lines)
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec


def available_inference_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.inference,
            provider_id="meta-reference",
            pip_packages=[
                "accelerate",
                "blobfile",
                "codeshield",
                "fairscale",
                "fbgemm-gpu==0.8.0",
                "torch",
                "transformers",
                "zmq",
            ],
            module="llama_toolchain.inference.meta_reference",
            config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
        ),
        InlineProviderSpec(
            api=Api.inference,
            provider_id="meta-ollama",
            pip_packages=[
                "ollama",
            ],
            module="llama_toolchain.inference.ollama",
            config_class="llama_toolchain.inference.ollama.OllamaImplConfig",
        ),
    ]
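For illustration (not part of this diff), a minimal sketch of how a distribution might index these specs; the dict-by-provider_id shape is an assumption, not the toolchain's actual registry code:

providers = {p.provider_id: p for p in available_inference_providers()}
spec = providers["meta-ollama"]
print(spec.module)        # "llama_toolchain.inference.ollama"
print(spec.pip_packages)  # ["ollama"], i.e. what a distribution install would pip-install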
@@ -17,7 +17,7 @@ from llama_models.llama3_1.api.model import Transformer, TransformerBlock

from llama_toolchain.inference.api.config import (
    CheckpointQuantizationFormat,
    InlineImplConfig,
    MetaReferenceImplConfig,
)
from llama_toolchain.inference.api.datatypes import QuantizationType


@@ -46,7 +46,7 @@ def swiglu_wrapper(

def convert_to_quantized_model(
    model: Transformer,
    config: InlineImplConfig,
    config: MetaReferenceImplConfig,
    fp8_activation_scale_ub: Optional[float] = 1200.0,
) -> Transformer:
    if config.quantization.type == QuantizationType.bf16.value:


@@ -1,119 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import signal

import fire

from dotenv import load_dotenv

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse

from hydra_zen import instantiate

from llama_toolchain.utils import get_default_config_dir, parse_config
from .api.endpoints import ChatCompletionRequest, ChatCompletionResponseStreamChunk

from .api_instance import get_inference_api_instance


load_dotenv()


GLOBAL_CONFIG = None


def get_config():
    return GLOBAL_CONFIG


def handle_sigint(*args, **kwargs):
    print("SIGINT or CTRL-C detected. Exiting gracefully", args)
    loop = asyncio.get_event_loop()
    for task in asyncio.all_tasks(loop):
        task.cancel()
    loop.stop()


app = FastAPI()


@app.on_event("startup")
async def startup():
    global InferenceApiInstance

    config = get_config()

    inference_config = instantiate(config["inference_config"])
    InferenceApiInstance = await get_inference_api_instance(
        inference_config,
    )
    await InferenceApiInstance.initialize()


@app.on_event("shutdown")
async def shutdown():
    global InferenceApiInstance

    print("shutting down")
    await InferenceApiInstance.shutdown()


# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
semaphore = asyncio.Semaphore(1)


@app.post(
    "/inference/chat_completion", response_model=ChatCompletionResponseStreamChunk
)
def chat_completion(request: Request, exec_request: ChatCompletionRequest):
    if semaphore.locked():
        raise HTTPException(
            status_code=429,
            detail="Only a single concurrent request allowed right now.",
        )

    async def sse_generator(event_gen):
        try:
            async for event in event_gen:
                yield f"data: {event.json()}\n\n"
                await asyncio.sleep(0.01)
        except asyncio.CancelledError:
            print("Generator cancelled")
            await event_gen.aclose()
        finally:
            semaphore.release()

    async def event_gen():
        async for event in InferenceApiInstance.chat_completion(exec_request):
            yield event

    return StreamingResponse(
        sse_generator(event_gen()),
        media_type="text/event-stream",
    )


def main(config_path: str, port: int = 5000, disable_ipv6: bool = False):
    global GLOBAL_CONFIG
    config_dir = get_default_config_dir()
    GLOBAL_CONFIG = parse_config(config_dir, config_path)

    signal.signal(signal.SIGINT, handle_sigint)

    import uvicorn

    # FYI this does not do hot-reloads
    listen_host = "::" if not disable_ipv6 else "0.0.0.0"
    print(f"Listening on {listen_host}:{port}")
    uvicorn.run(app, host=listen_host, port=port)


if __name__ == "__main__":
    fire.Fire(main)