Added support for structured output in the API and added a reference implementation for meta-reference. A few notes:

* Two formats are specified in the API: JSON schema and EBNF-based grammar.
* The implementation only supports JSON schema for now.

We currently use lm-format-enforcer to provide the implementation, but this may change, especially because that library does not support BNF grammars. Fireworks has support for structured output, and Together has limited support for it as well; subsequent PRs will add those providers. We would like all of our inference providers to offer structured output for Llama models, since it is an extremely important and highly sought-after capability for developers.
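As a rough sketch of what a JSON-schema-constrained request could look like through the client below: the response-format class name (`JsonSchemaResponseFormat`), its `json_schema` field, the `completion_message` field on the response, the module path, and the server address are all assumptions for illustration, not confirmed by this PR.

```python
# Hypothetical usage sketch only. JsonSchemaResponseFormat, its json_schema
# field, response.completion_message, the module path, and localhost:5000 are
# assumptions; the grammar/EBNF variant exists in the API but is not yet
# supported by the reference implementation.
import asyncio

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.inference.client import InferenceClient


async def structured_output_demo() -> None:
    client = InferenceClient("http://localhost:5000")

    # Constrain the model to emit JSON matching this schema.
    person_schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "year_born": {"type": "string"},
        },
        "required": ["name", "year_born"],
    }

    response = await client.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=[
            UserMessage(
                content="Michael Jordan was born in 1963. Extract the person as JSON."
            )
        ],
        response_format=JsonSchemaResponseFormat(json_schema=person_schema),
        stream=False,
    )
    print(response.completion_message.content)


if __name__ == "__main__":
    asyncio.run(structured_output_demo())
```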
200 lines
5.9 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import json
from typing import Any, AsyncGenerator, List, Optional

import fire
import httpx

from llama_models.llama3.api.datatypes import ImageMedia, URL

from pydantic import BaseModel

from llama_models.llama3.api import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403
from termcolor import cprint

from llama_stack.distribution.datatypes import RemoteProviderConfig

from .event_logger import EventLogger


async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
    return InferenceClient(config.url)


def encodable_dict(d: BaseModel):
    # Round-trip through the model's JSON representation so the result contains
    # only JSON-encodable types (suitable for httpx's `json=` argument).
    return json.loads(d.json())


class InferenceClient(Inference):
    def __init__(self, base_url: str):
        self.base_url = base_url

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
        raise NotImplementedError()

    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        request = ChatCompletionRequest(
            model=model,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            # Streaming: hand back the async generator for the caller to iterate.
            return self._stream_chat_completion(request)
        else:
            # Non-streaming: await the full response before returning it.
            return await self._nonstream_chat_completion(request)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/inference/chat_completion",
                json=encodable_dict(request),
                headers={"Content-Type": "application/json"},
                timeout=20,
            )

            response.raise_for_status()
            j = response.json()
            return ChatCompletionResponse(**j)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator:
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                f"{self.base_url}/inference/chat_completion",
                json=encodable_dict(request),
                headers={"Content-Type": "application/json"},
                timeout=20,
            ) as response:
                if response.status_code != 200:
                    content = await response.aread()
                    cprint(
                        f"Error: HTTP {response.status_code} {content.decode()}",
                        "red",
                    )
                    return

                # The server streams server-sent events; payload lines are
                # prefixed with "data:".
                async for line in response.aiter_lines():
                    if line.startswith("data:"):
                        data = line[len("data:") :].lstrip()
                        try:
                            if "error" in data:
                                cprint(data, "red")
                                continue

                            yield ChatCompletionResponseStreamChunk(**json.loads(data))
                        except Exception as e:
                            print(data)
                            print(f"Error with parsing or validation: {e}")


async def run_main(
    host: str, port: int, stream: bool, model: Optional[str], logprobs: bool
):
    client = InferenceClient(f"http://{host}:{port}")

    if not model:
        model = "Llama3.1-8B-Instruct"

    message = UserMessage(
        content="hello world, write me a 2 sentence poem about the moon"
    )
    cprint(f"User>{message.content}", "green")

    if logprobs:
        logprobs_config = LogProbConfig(
            top_k=1,
        )
    else:
        logprobs_config = None

    assert stream, "Non streaming not supported here"
    iterator = await client.chat_completion(
        model=model,
        messages=[message],
        stream=stream,
        logprobs=logprobs_config,
    )

    if logprobs:
        async for chunk in iterator:
            cprint(f"Response: {chunk}", "red")
    else:
        async for log in EventLogger().log(iterator):
            log.print()


async def run_mm_main(
    host: str, port: int, stream: bool, path: Optional[str], model: Optional[str]
):
    client = InferenceClient(f"http://{host}:{port}")

    if not model:
        model = "Llama3.2-11B-Vision-Instruct"

    message = UserMessage(
        content=[
            ImageMedia(image=URL(uri=f"file://{path}")),
            "Describe this image in two sentences",
        ],
    )
    cprint(f"User>{message.content}", "green")
    iterator = await client.chat_completion(
        model=model,
        messages=[message],
        stream=stream,
    )
    async for log in EventLogger().log(iterator):
        log.print()


def main(
    host: str,
    port: int,
    stream: bool = True,
    mm: bool = False,
    logprobs: bool = False,
    file: Optional[str] = None,
    model: Optional[str] = None,
):
    if mm:
        asyncio.run(run_mm_main(host, port, stream, file, model))
    else:
        asyncio.run(run_main(host, port, stream, model, logprobs))


if __name__ == "__main__":
    fire.Fire(main)
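With the `fire`-based entry point above, a typical streaming run against a local server would look something like `python -m llama_stack.apis.inference.client localhost 5000` (module path assumed), with `--mm=True --file=<image path>` switching to the multimodal flow and `--logprobs=True` printing raw stream chunks instead of the event-logged output.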