Remove request wrapper migration (#64)

* [1/n] migrate inference/chat_completion

* migrate inference/completion

* inference/completion

* inference regenerate openapi spec

* safety api

* migrate agentic system

* migrate apis without implementations

* re-generate openapi spec

* remove hack from openapi generator

* fix inference

* fix inference

* openapi generator rerun

* Simplified the Telemetry API and tied it to the logger (#57)

* Simplified the Telemetry API and tied it to the logger

* small update which adds a METRIC type

* move span events one level down into structured log events

---------

Co-authored-by: Ashwin Bharambe <ashwin@meta.com>

* fix api to work with openapi generator

* fix agentic calling inference

* together adapter inference

* update inference adapters

---------

Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
Xi Yan 2024-09-12 15:03:49 -07:00 committed by GitHub
parent 1d0e91d802
commit 5712566061
26 changed files with 1211 additions and 3031 deletions
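Across the files below the pattern is the same: public API methods lose their single request-wrapper parameter and take the wrapper's fields directly, while implementations that still want one object rebuild the wrapper internally. A minimal before/after sketch for inference (a hedged illustration; the model id and message are placeholders, not taken from the diff):

    # before: callers built a wrapper object and passed it as the only argument
    request = ChatCompletionRequest(
        model="some-model-id",
        messages=[UserMessage(content="hello")],
        stream=True,
    )
    async for chunk in inference_api.chat_completion(request):
        ...

    # after: the wrapper is gone from the public API; its fields become parameters
    async for chunk in inference_api.chat_completion(
        model="some-model-id",
        messages=[UserMessage(content="hello")],
        stream=True,
    ):
        ...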

@@ -416,7 +416,16 @@ class AgenticSystem(Protocol):
     @webmethod(route="/agentic_system/turn/create")
     async def create_agentic_system_turn(
         self,
-        request: AgenticSystemTurnCreateRequest,
+        agent_id: str,
+        session_id: str,
+        messages: List[
+            Union[
+                UserMessage,
+                ToolResponseMessage,
+            ]
+        ],
+        attachments: Optional[List[Attachment]] = None,
+        stream: Optional[bool] = False,
     ) -> AgenticSystemTurnResponseStreamChunk: ...
 
     @webmethod(route="/agentic_system/turn/get")

@@ -73,9 +73,7 @@ class AgenticSystemClient(AgenticSystem):
         async with client.stream(
             "POST",
             f"{self.base_url}/agentic_system/turn/create",
-            json={
-                "request": encodable_dict(request),
-            },
+            json=encodable_dict(request),
             headers={"Content-Type": "application/json"},
             timeout=20,
         ) as response:

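The wire-level effect of the client change above: the POST body for /agentic_system/turn/create (and the other migrated routes) is now the request's own fields at the top level instead of a single nested "request" object. A rough sketch (field values are placeholders, not taken from the diff):

    # before: everything nested under one "request" key
    body = {"request": {"agent_id": "abc", "session_id": "xyz", "messages": [], "stream": True}}

    # after: the fields themselves form the JSON body
    body = {"agent_id": "abc", "session_id": "xyz", "messages": [], "stream": True}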
@@ -388,19 +388,17 @@ class ChatAgent(ShieldRunnerMixin):
                 )
             )
 
-            req = ChatCompletionRequest(
-                model=self.agent_config.model,
-                messages=input_messages,
+            tool_calls = []
+            content = ""
+            stop_reason = None
+            async for chunk in self.inference_api.chat_completion(
+                self.agent_config.model,
+                input_messages,
                 tools=self._get_tools(),
                 tool_prompt_format=self.agent_config.tool_prompt_format,
                 stream=True,
                 sampling_params=sampling_params,
-            )
-
-            tool_calls = []
-            content = ""
-            stop_reason = None
-            async for chunk in self.inference_api.chat_completion(req):
+            ):
                 event = chunk.event
                 if event.event_type == ChatCompletionResponseEventType.start:
                     continue

@@ -114,8 +114,26 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
     async def create_agentic_system_turn(
         self,
-        request: AgenticSystemTurnCreateRequest,
+        agent_id: str,
+        session_id: str,
+        messages: List[
+            Union[
+                UserMessage,
+                ToolResponseMessage,
+            ]
+        ],
+        attachments: Optional[List[Attachment]] = None,
+        stream: Optional[bool] = False,
     ) -> AsyncGenerator:
+        # wrapper request to make it easier to pass around (internal only, not exposed to API)
+        request = AgenticSystemTurnCreateRequest(
+            agent_id=agent_id,
+            session_id=session_id,
+            messages=messages,
+            attachments=attachments,
+            stream=stream,
+        )
+
         agent_id = request.agent_id
         assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
         agent = AGENT_INSTANCES_BY_ID[agent_id]

@@ -51,11 +51,21 @@ class BatchInference(Protocol):
     @webmethod(route="/batch_inference/completion")
     async def batch_completion(
         self,
-        request: BatchCompletionRequest,
+        model: str,
+        content_batch: List[InterleavedTextMedia],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        logprobs: Optional[LogProbConfig] = None,
     ) -> BatchCompletionResponse: ...
 
     @webmethod(route="/batch_inference/chat_completion")
     async def batch_chat_completion(
         self,
-        request: BatchChatCompletionRequest,
+        model: str,
+        messages_batch: List[List[Message]],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        # zero-shot tool definitions as input to the model
+        tools: Optional[List[ToolDefinition]] = list,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> BatchChatCompletionResponse: ...

@@ -46,7 +46,8 @@ class Datasets(Protocol):
     @webmethod(route="/datasets/create")
     def create_dataset(
         self,
-        request: CreateDatasetRequest,
+        uuid: str,
+        dataset: TrainEvalDataset,
     ) -> None: ...
 
     @webmethod(route="/datasets/get")

@@ -86,19 +86,19 @@ class Evaluations(Protocol):
     @webmethod(route="/evaluate/text_generation/")
     def evaluate_text_generation(
         self,
-        request: EvaluateTextGenerationRequest,
+        metrics: List[TextGenerationMetric],
     ) -> EvaluationJob: ...
 
     @webmethod(route="/evaluate/question_answering/")
     def evaluate_question_answering(
         self,
-        request: EvaluateQuestionAnsweringRequest,
+        metrics: List[QuestionAnsweringMetric],
     ) -> EvaluationJob: ...
 
     @webmethod(route="/evaluate/summarization/")
     def evaluate_summarization(
         self,
-        request: EvaluateSummarizationRequest,
+        metrics: List[SummarizationMetric],
     ) -> EvaluationJob: ...
 
     @webmethod(route="/evaluate/jobs")

@@ -76,7 +76,28 @@ class FireworksInferenceAdapter(Inference):
         return options
 
-    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = list(),
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools,
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
         messages = prepare_messages(request)
 
         # accumulate sampling params and other options to pass to fireworks

@@ -84,7 +84,28 @@ class OllamaInferenceAdapter(Inference):
         return options
 
-    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = list(),
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools,
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
         messages = prepare_messages(request)
 
         # accumulate sampling params and other options to pass to ollama
         options = self.get_ollama_chat_options(request)

@@ -82,7 +82,28 @@ class TGIAdapter(Inference):
         return options
 
-    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = list(),
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools,
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
         messages = prepare_messages(request)
 
         model_input = self.formatter.encode_dialog_prompt(messages)
         prompt = self.tokenizer.decode(model_input.tokens)

@@ -76,7 +76,29 @@ class TogetherInferenceAdapter(Inference):
         return options
 
-    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = list(),
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        # wrapper request to make it easier to pass around (internal only, not exposed to API)
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools,
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
         # accumulate sampling params and other options to pass to together
         options = self.get_together_chat_options(request)
         together_model = self.resolve_together_model(request.model)

@@ -85,6 +85,8 @@ class CompletionRequest(BaseModel):
 @json_schema_type
 class CompletionResponse(BaseModel):
+    """Completion response."""
+
     completion_message: CompletionMessage
     logprobs: Optional[List[TokenLogProbs]] = None

@@ -108,6 +110,8 @@ class BatchCompletionRequest(BaseModel):
 @json_schema_type
 class BatchCompletionResponse(BaseModel):
+    """Batch completion response."""
+
     completion_message_batch: List[CompletionMessage]

@@ -137,6 +141,8 @@ class ChatCompletionResponseStreamChunk(BaseModel):
 @json_schema_type
 class ChatCompletionResponse(BaseModel):
+    """Chat completion response."""
+
     completion_message: CompletionMessage
     logprobs: Optional[List[TokenLogProbs]] = None

@@ -170,13 +176,25 @@ class Inference(Protocol):
     @webmethod(route="/inference/completion")
     async def completion(
         self,
-        request: CompletionRequest,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
 
     @webmethod(route="/inference/chat_completion")
     async def chat_completion(
         self,
-        request: ChatCompletionRequest,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        # zero-shot tool definitions as input to the model
+        tools: Optional[List[ToolDefinition]] = list,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
 
     @webmethod(route="/inference/embeddings")

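A caller-side sketch of the unrolled completion signature above, assuming the non-streaming path returns a CompletionResponse (the model id and prompt are placeholders, and obtaining the client or implementation object is elided):

    # assumes `api` is an object implementing the Inference protocol
    response = await api.completion(
        model="some-model-id",                # placeholder
        content="The capital of France is",   # placeholder prompt
        sampling_params=SamplingParams(),
        stream=False,
    )
    print(response.completion_message.content)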
@@ -10,10 +10,10 @@ from typing import Any, AsyncGenerator
 import fire
 import httpx
-from pydantic import BaseModel
-from termcolor import cprint
 
 from llama_toolchain.core.datatypes import RemoteProviderConfig
+from pydantic import BaseModel
+from termcolor import cprint
 
 from .api import (
     ChatCompletionRequest,

@@ -52,9 +52,7 @@ class InferenceClient(Inference):
         async with client.stream(
             "POST",
             f"{self.base_url}/inference/chat_completion",
-            json={
-                "request": encodable_dict(request),
-            },
+            json=encodable_dict(request),
             headers={"Content-Type": "application/json"},
             timeout=20,
         ) as response:

@@ -22,9 +22,12 @@ from llama_toolchain.inference.api import (
     ToolCallParseStatus,
 )
 from llama_toolchain.inference.prepare_messages import prepare_messages
 
 from .config import MetaReferenceImplConfig
 from .model_parallel import LlamaModelParallelGenerator
 
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_toolchain.inference.api import *  # noqa: F403
+
 # there's a single model parallel process running serving the model. for now,
 # we don't support multiple concurrent requests to this process.

@@ -50,10 +53,30 @@ class MetaReferenceInferenceImpl(Inference):
     # hm, when stream=False, we should not be doing SSE :/ which is what the
     # top-level server is going to do. make the typing more specific here
     async def chat_completion(
-        self, request: ChatCompletionRequest
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = list(),
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncIterator[
         Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
     ]:
+        # wrapper request to make it easier to pass around (internal only, not exposed to API)
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools,
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
         messages = prepare_messages(request)
         model = resolve_model(request.model)
         if model is None:

@@ -179,13 +179,33 @@ class PostTraining(Protocol):
     @webmethod(route="/post_training/supervised_fine_tune")
     def supervised_fine_tune(
         self,
-        request: PostTrainingSFTRequest,
+        job_uuid: str,
+        model: str,
+        dataset: TrainEvalDataset,
+        validation_dataset: TrainEvalDataset,
+        algorithm: FinetuningAlgorithm,
+        algorithm_config: Union[
+            LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
+        ],
+        optimizer_config: OptimizerConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post_training/preference_optimize")
     def preference_optimize(
         self,
-        request: PostTrainingRLHFRequest,
+        job_uuid: str,
+        finetuned_model: URL,
+        dataset: TrainEvalDataset,
+        validation_dataset: TrainEvalDataset,
+        algorithm: RLHFAlgorithm,
+        algorithm_config: Union[DPOAlignmentConfig],
+        optimizer_config: OptimizerConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post_training/jobs")

@@ -50,5 +50,6 @@ class RewardScoring(Protocol):
     @webmethod(route="/reward_scoring/score")
     def reward_score(
         self,
-        request: RewardScoringRequest,
+        dialog_generations: List[DialogGenerations],
+        model: str,
     ) -> Union[RewardScoringResponse]: ...

@@ -86,5 +86,6 @@ class Safety(Protocol):
     @webmethod(route="/safety/run_shields")
     async def run_shields(
         self,
-        request: RunShieldRequest,
+        messages: List[Message],
+        shields: List[ShieldDefinition],
     ) -> RunShieldResponse: ...

@@ -13,10 +13,10 @@ import fire
 import httpx
 
 from llama_models.llama3.api.datatypes import UserMessage
-from pydantic import BaseModel
-from termcolor import cprint
 
 from llama_toolchain.core.datatypes import RemoteProviderConfig
+from pydantic import BaseModel
+from termcolor import cprint
 
 from .api import *  # noqa: F403

@@ -43,9 +43,7 @@ class SafetyClient(Safety):
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 f"{self.base_url}/safety/run_shields",
-                json={
-                    "request": encodable_dict(request),
-                },
+                json=encodable_dict(request),
                 headers={"Content-Type": "application/json"},
                 timeout=20,
             )

@@ -52,13 +52,12 @@ class MetaReferenceSafetyImpl(Safety):
     async def run_shields(
         self,
-        request: RunShieldRequest,
+        messages: List[Message],
+        shields: List[ShieldDefinition],
     ) -> RunShieldResponse:
-        shields = [shield_config_to_shield(c, self.config) for c in request.shields]
+        shields = [shield_config_to_shield(c, self.config) for c in shields]
 
-        responses = await asyncio.gather(
-            *[shield.run(request.messages) for shield in shields]
-        )
+        responses = await asyncio.gather(*[shield.run(messages) for shield in shields])
 
         return RunShieldResponse(responses=responses)

@@ -48,5 +48,7 @@ class SyntheticDataGeneration(Protocol):
     @webmethod(route="/synthetic_data_generation/generate")
     def synthetic_data_generate(
         self,
-        request: SyntheticDataGenerationRequest,
+        dialogs: List[Message],
+        filtering_function: FilteringFunction = FilteringFunction.none,
+        model: Optional[str] = None,
     ) -> Union[SyntheticDataGenerationResponse]: ...

@@ -125,7 +125,7 @@ Event = Annotated[
 
 class Telemetry(Protocol):
     @webmethod(route="/telemetry/log_event")
-    async def log_event(self, event: Event): ...
+    async def log_event(self, event: Event) -> None: ...
 
     @webmethod(route="/telemetry/get_trace", method="GET")
     async def get_trace(self, trace_id: str) -> Trace: ...

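A hedged usage sketch of the simplified Telemetry API: only log_event and get_trace appear in the diff above, so the concrete event class and its fields below are assumptions rather than the committed schema.

    # UnstructuredLogEvent / LogSeverity and their fields are assumed names, not from the diff
    await telemetry.log_event(
        UnstructuredLogEvent(
            trace_id=trace_id,
            span_id=span_id,
            severity=LogSeverity.INFO,
            message="starting chat_completion",
        )
    )
    trace = await telemetry.get_trace(trace_id)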
(Two file diffs suppressed because they are too large.)

@@ -471,11 +471,6 @@ class Generator:
         from dataclasses import make_dataclass
 
-        if len(op.request_params) == 1 and "Request" in first[1].__name__:
-            # TODO(ashwin): Undo the "Request" hack and this entire block eventually
-            request_name = first[1].__name__ + "Wrapper"
-            request_type = make_dataclass(request_name, op.request_params)
-        else:
         op_name = "".join(word.capitalize() for word in op.name.split("_"))
         request_name = f"{op_name}Request"
         request_type = make_dataclass(request_name, op.request_params)

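For reference, the generator synthesizes a request body type for every webmethod with make_dataclass; the removed block special-cased operations whose only parameter was already a *Request model and re-wrapped it (hence "Wrapper"), which is why the wire payloads were nested under a single "request" key. A standalone sketch of the remaining mechanism, assuming op.request_params is a list of (name, type) pairs as make_dataclass expects (the sample operation and fields here are made up):

    from dataclasses import make_dataclass
    from typing import List

    request_params = [("model", str), ("messages", List[str])]  # stand-in for op.request_params
    op_name = "".join(word.capitalize() for word in "chat_completion".split("_"))
    request_name = f"{op_name}Request"  # -> "ChatCompletionRequest"
    request_type = make_dataclass(request_name, request_params)

    print(request_type(model="m", messages=["hi"]))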
@@ -249,7 +249,12 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
             stream=True,
             tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
         )
-        iterator = self.api.chat_completion(request)
+        iterator = self.api.chat_completion(
+            request.model,
+            request.messages,
+            stream=request.stream,
+            tools=request.tools,
+        )
 
         events = []
         async for chunk in iterator:

@@ -61,7 +61,9 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             ],
             stream=False,
         )
-        iterator = self.api.chat_completion(request)
+        iterator = self.api.chat_completion(
+            request.model, request.messages, stream=request.stream
+        )
         async for r in iterator:
             response = r
         print(response.completion_message.content)