forked from phoenix-oss/llama-stack-mirror
Remove request wrapper migration (#64)
* [1/n] migrate inference/chat_completion * migrate inference/completion * inference/completion * inference regenerate openapi spec * safety api * migrate agentic system * migrate apis without implementations * re-generate openapi spec * remove hack from openapi generator * fix inference * fix inference * openapi generator rerun * Simplified Telemetry API and tying it to logger (#57) * Simplified Telemetry API and tying it to logger * small update which adds a METRIC type * move span events one level down into structured log events --------- Co-authored-by: Ashwin Bharambe <ashwin@meta.com> * fix api to work with openapi generator * fix agentic calling inference * together adapter inference * update inference adapters --------- Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com> Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
This commit is contained in:
parent
1d0e91d802
commit
5712566061
26 changed files with 1211 additions and 3031 deletions
|
@ -416,7 +416,16 @@ class AgenticSystem(Protocol):
|
||||||
@webmethod(route="/agentic_system/turn/create")
|
@webmethod(route="/agentic_system/turn/create")
|
||||||
async def create_agentic_system_turn(
|
async def create_agentic_system_turn(
|
||||||
self,
|
self,
|
||||||
request: AgenticSystemTurnCreateRequest,
|
agent_id: str,
|
||||||
|
session_id: str,
|
||||||
|
messages: List[
|
||||||
|
Union[
|
||||||
|
UserMessage,
|
||||||
|
ToolResponseMessage,
|
||||||
|
]
|
||||||
|
],
|
||||||
|
attachments: Optional[List[Attachment]] = None,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
) -> AgenticSystemTurnResponseStreamChunk: ...
|
) -> AgenticSystemTurnResponseStreamChunk: ...
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/turn/get")
|
@webmethod(route="/agentic_system/turn/get")
|
||||||
|
|
|
@ -73,9 +73,7 @@ class AgenticSystemClient(AgenticSystem):
|
||||||
async with client.stream(
|
async with client.stream(
|
||||||
"POST",
|
"POST",
|
||||||
f"{self.base_url}/agentic_system/turn/create",
|
f"{self.base_url}/agentic_system/turn/create",
|
||||||
json={
|
json=encodable_dict(request),
|
||||||
"request": encodable_dict(request),
|
|
||||||
},
|
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
timeout=20,
|
timeout=20,
|
||||||
) as response:
|
) as response:
|
||||||
|
|
|
@ -388,19 +388,17 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
req = ChatCompletionRequest(
|
tool_calls = []
|
||||||
model=self.agent_config.model,
|
content = ""
|
||||||
messages=input_messages,
|
stop_reason = None
|
||||||
|
async for chunk in self.inference_api.chat_completion(
|
||||||
|
self.agent_config.model,
|
||||||
|
input_messages,
|
||||||
tools=self._get_tools(),
|
tools=self._get_tools(),
|
||||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||||
stream=True,
|
stream=True,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
)
|
):
|
||||||
|
|
||||||
tool_calls = []
|
|
||||||
content = ""
|
|
||||||
stop_reason = None
|
|
||||||
async for chunk in self.inference_api.chat_completion(req):
|
|
||||||
event = chunk.event
|
event = chunk.event
|
||||||
if event.event_type == ChatCompletionResponseEventType.start:
|
if event.event_type == ChatCompletionResponseEventType.start:
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -114,8 +114,26 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
||||||
|
|
||||||
async def create_agentic_system_turn(
|
async def create_agentic_system_turn(
|
||||||
self,
|
self,
|
||||||
request: AgenticSystemTurnCreateRequest,
|
agent_id: str,
|
||||||
|
session_id: str,
|
||||||
|
messages: List[
|
||||||
|
Union[
|
||||||
|
UserMessage,
|
||||||
|
ToolResponseMessage,
|
||||||
|
]
|
||||||
|
],
|
||||||
|
attachments: Optional[List[Attachment]] = None,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||||
|
request = AgenticSystemTurnCreateRequest(
|
||||||
|
agent_id=agent_id,
|
||||||
|
session_id=session_id,
|
||||||
|
messages=messages,
|
||||||
|
attachments=attachments,
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
|
||||||
agent_id = request.agent_id
|
agent_id = request.agent_id
|
||||||
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
|
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
|
||||||
agent = AGENT_INSTANCES_BY_ID[agent_id]
|
agent = AGENT_INSTANCES_BY_ID[agent_id]
|
||||||
|
|
|
@ -51,11 +51,21 @@ class BatchInference(Protocol):
|
||||||
@webmethod(route="/batch_inference/completion")
|
@webmethod(route="/batch_inference/completion")
|
||||||
async def batch_completion(
|
async def batch_completion(
|
||||||
self,
|
self,
|
||||||
request: BatchCompletionRequest,
|
model: str,
|
||||||
|
content_batch: List[InterleavedTextMedia],
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> BatchCompletionResponse: ...
|
) -> BatchCompletionResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/batch_inference/chat_completion")
|
@webmethod(route="/batch_inference/chat_completion")
|
||||||
async def batch_chat_completion(
|
async def batch_chat_completion(
|
||||||
self,
|
self,
|
||||||
request: BatchChatCompletionRequest,
|
model: str,
|
||||||
|
messages_batch: List[List[Message]],
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
|
# zero-shot tool definitions as input to the model
|
||||||
|
tools: Optional[List[ToolDefinition]] = list,
|
||||||
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||||
|
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> BatchChatCompletionResponse: ...
|
) -> BatchChatCompletionResponse: ...
|
||||||
|
|
|
@ -46,7 +46,8 @@ class Datasets(Protocol):
|
||||||
@webmethod(route="/datasets/create")
|
@webmethod(route="/datasets/create")
|
||||||
def create_dataset(
|
def create_dataset(
|
||||||
self,
|
self,
|
||||||
request: CreateDatasetRequest,
|
uuid: str,
|
||||||
|
dataset: TrainEvalDataset,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
@webmethod(route="/datasets/get")
|
@webmethod(route="/datasets/get")
|
||||||
|
|
|
@ -86,19 +86,19 @@ class Evaluations(Protocol):
|
||||||
@webmethod(route="/evaluate/text_generation/")
|
@webmethod(route="/evaluate/text_generation/")
|
||||||
def evaluate_text_generation(
|
def evaluate_text_generation(
|
||||||
self,
|
self,
|
||||||
request: EvaluateTextGenerationRequest,
|
metrics: List[TextGenerationMetric],
|
||||||
) -> EvaluationJob: ...
|
) -> EvaluationJob: ...
|
||||||
|
|
||||||
@webmethod(route="/evaluate/question_answering/")
|
@webmethod(route="/evaluate/question_answering/")
|
||||||
def evaluate_question_answering(
|
def evaluate_question_answering(
|
||||||
self,
|
self,
|
||||||
request: EvaluateQuestionAnsweringRequest,
|
metrics: List[QuestionAnsweringMetric],
|
||||||
) -> EvaluationJob: ...
|
) -> EvaluationJob: ...
|
||||||
|
|
||||||
@webmethod(route="/evaluate/summarization/")
|
@webmethod(route="/evaluate/summarization/")
|
||||||
def evaluate_summarization(
|
def evaluate_summarization(
|
||||||
self,
|
self,
|
||||||
request: EvaluateSummarizationRequest,
|
metrics: List[SummarizationMetric],
|
||||||
) -> EvaluationJob: ...
|
) -> EvaluationJob: ...
|
||||||
|
|
||||||
@webmethod(route="/evaluate/jobs")
|
@webmethod(route="/evaluate/jobs")
|
||||||
|
|
|
@ -76,7 +76,28 @@ class FireworksInferenceAdapter(Inference):
|
||||||
|
|
||||||
return options
|
return options
|
||||||
|
|
||||||
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
async def chat_completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[Message],
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
|
tools: Optional[List[ToolDefinition]] = list(),
|
||||||
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||||
|
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
request = ChatCompletionRequest(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
tool_prompt_format=tool_prompt_format,
|
||||||
|
stream=stream,
|
||||||
|
logprobs=logprobs,
|
||||||
|
)
|
||||||
|
|
||||||
messages = prepare_messages(request)
|
messages = prepare_messages(request)
|
||||||
|
|
||||||
# accumulate sampling params and other options to pass to fireworks
|
# accumulate sampling params and other options to pass to fireworks
|
||||||
|
|
|
@ -84,7 +84,28 @@ class OllamaInferenceAdapter(Inference):
|
||||||
|
|
||||||
return options
|
return options
|
||||||
|
|
||||||
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
async def chat_completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[Message],
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
|
tools: Optional[List[ToolDefinition]] = list(),
|
||||||
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||||
|
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
request = ChatCompletionRequest(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
tool_prompt_format=tool_prompt_format,
|
||||||
|
stream=stream,
|
||||||
|
logprobs=logprobs,
|
||||||
|
)
|
||||||
|
|
||||||
messages = prepare_messages(request)
|
messages = prepare_messages(request)
|
||||||
# accumulate sampling params and other options to pass to ollama
|
# accumulate sampling params and other options to pass to ollama
|
||||||
options = self.get_ollama_chat_options(request)
|
options = self.get_ollama_chat_options(request)
|
||||||
|
|
|
@ -82,7 +82,28 @@ class TGIAdapter(Inference):
|
||||||
|
|
||||||
return options
|
return options
|
||||||
|
|
||||||
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
async def chat_completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[Message],
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
|
tools: Optional[List[ToolDefinition]] = list(),
|
||||||
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||||
|
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
request = ChatCompletionRequest(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
tool_prompt_format=tool_prompt_format,
|
||||||
|
stream=stream,
|
||||||
|
logprobs=logprobs,
|
||||||
|
)
|
||||||
|
|
||||||
messages = prepare_messages(request)
|
messages = prepare_messages(request)
|
||||||
model_input = self.formatter.encode_dialog_prompt(messages)
|
model_input = self.formatter.encode_dialog_prompt(messages)
|
||||||
prompt = self.tokenizer.decode(model_input.tokens)
|
prompt = self.tokenizer.decode(model_input.tokens)
|
||||||
|
|
|
@ -76,7 +76,29 @@ class TogetherInferenceAdapter(Inference):
|
||||||
|
|
||||||
return options
|
return options
|
||||||
|
|
||||||
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
async def chat_completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[Message],
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
|
tools: Optional[List[ToolDefinition]] = list(),
|
||||||
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||||
|
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||||
|
request = ChatCompletionRequest(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
tool_prompt_format=tool_prompt_format,
|
||||||
|
stream=stream,
|
||||||
|
logprobs=logprobs,
|
||||||
|
)
|
||||||
|
|
||||||
# accumulate sampling params and other options to pass to together
|
# accumulate sampling params and other options to pass to together
|
||||||
options = self.get_together_chat_options(request)
|
options = self.get_together_chat_options(request)
|
||||||
together_model = self.resolve_together_model(request.model)
|
together_model = self.resolve_together_model(request.model)
|
||||||
|
|
|
@ -85,6 +85,8 @@ class CompletionRequest(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class CompletionResponse(BaseModel):
|
class CompletionResponse(BaseModel):
|
||||||
|
"""Completion response."""
|
||||||
|
|
||||||
completion_message: CompletionMessage
|
completion_message: CompletionMessage
|
||||||
logprobs: Optional[List[TokenLogProbs]] = None
|
logprobs: Optional[List[TokenLogProbs]] = None
|
||||||
|
|
||||||
|
@ -108,6 +110,8 @@ class BatchCompletionRequest(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class BatchCompletionResponse(BaseModel):
|
class BatchCompletionResponse(BaseModel):
|
||||||
|
"""Batch completion response."""
|
||||||
|
|
||||||
completion_message_batch: List[CompletionMessage]
|
completion_message_batch: List[CompletionMessage]
|
||||||
|
|
||||||
|
|
||||||
|
@ -137,6 +141,8 @@ class ChatCompletionResponseStreamChunk(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ChatCompletionResponse(BaseModel):
|
class ChatCompletionResponse(BaseModel):
|
||||||
|
"""Chat completion response."""
|
||||||
|
|
||||||
completion_message: CompletionMessage
|
completion_message: CompletionMessage
|
||||||
logprobs: Optional[List[TokenLogProbs]] = None
|
logprobs: Optional[List[TokenLogProbs]] = None
|
||||||
|
|
||||||
|
@ -170,13 +176,25 @@ class Inference(Protocol):
|
||||||
@webmethod(route="/inference/completion")
|
@webmethod(route="/inference/completion")
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
request: CompletionRequest,
|
model: str,
|
||||||
|
content: InterleavedTextMedia,
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
|
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
|
||||||
|
|
||||||
@webmethod(route="/inference/chat_completion")
|
@webmethod(route="/inference/chat_completion")
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
request: ChatCompletionRequest,
|
model: str,
|
||||||
|
messages: List[Message],
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
|
# zero-shot tool definitions as input to the model
|
||||||
|
tools: Optional[List[ToolDefinition]] = list,
|
||||||
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||||
|
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
|
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
|
||||||
|
|
||||||
@webmethod(route="/inference/embeddings")
|
@webmethod(route="/inference/embeddings")
|
||||||
|
|
|
@ -10,10 +10,10 @@ from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import httpx
|
import httpx
|
||||||
from pydantic import BaseModel
|
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_toolchain.core.datatypes import RemoteProviderConfig
|
from llama_toolchain.core.datatypes import RemoteProviderConfig
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from termcolor import cprint
|
||||||
|
|
||||||
from .api import (
|
from .api import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
|
@ -52,9 +52,7 @@ class InferenceClient(Inference):
|
||||||
async with client.stream(
|
async with client.stream(
|
||||||
"POST",
|
"POST",
|
||||||
f"{self.base_url}/inference/chat_completion",
|
f"{self.base_url}/inference/chat_completion",
|
||||||
json={
|
json=encodable_dict(request),
|
||||||
"request": encodable_dict(request),
|
|
||||||
},
|
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
timeout=20,
|
timeout=20,
|
||||||
) as response:
|
) as response:
|
||||||
|
|
|
@ -22,9 +22,12 @@ from llama_toolchain.inference.api import (
|
||||||
ToolCallParseStatus,
|
ToolCallParseStatus,
|
||||||
)
|
)
|
||||||
from llama_toolchain.inference.prepare_messages import prepare_messages
|
from llama_toolchain.inference.prepare_messages import prepare_messages
|
||||||
|
|
||||||
from .config import MetaReferenceImplConfig
|
from .config import MetaReferenceImplConfig
|
||||||
from .model_parallel import LlamaModelParallelGenerator
|
from .model_parallel import LlamaModelParallelGenerator
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
from llama_toolchain.inference.api import * # noqa: F403
|
||||||
|
|
||||||
# there's a single model parallel process running serving the model. for now,
|
# there's a single model parallel process running serving the model. for now,
|
||||||
# we don't support multiple concurrent requests to this process.
|
# we don't support multiple concurrent requests to this process.
|
||||||
|
@ -50,10 +53,30 @@ class MetaReferenceInferenceImpl(Inference):
|
||||||
# hm, when stream=False, we should not be doing SSE :/ which is what the
|
# hm, when stream=False, we should not be doing SSE :/ which is what the
|
||||||
# top-level server is going to do. make the typing more specific here
|
# top-level server is going to do. make the typing more specific here
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self, request: ChatCompletionRequest
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[Message],
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
|
tools: Optional[List[ToolDefinition]] = list(),
|
||||||
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||||
|
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncIterator[
|
) -> AsyncIterator[
|
||||||
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
|
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
|
||||||
]:
|
]:
|
||||||
|
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||||
|
request = ChatCompletionRequest(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
tool_prompt_format=tool_prompt_format,
|
||||||
|
stream=stream,
|
||||||
|
logprobs=logprobs,
|
||||||
|
)
|
||||||
|
|
||||||
messages = prepare_messages(request)
|
messages = prepare_messages(request)
|
||||||
model = resolve_model(request.model)
|
model = resolve_model(request.model)
|
||||||
if model is None:
|
if model is None:
|
||||||
|
|
|
@ -179,13 +179,33 @@ class PostTraining(Protocol):
|
||||||
@webmethod(route="/post_training/supervised_fine_tune")
|
@webmethod(route="/post_training/supervised_fine_tune")
|
||||||
def supervised_fine_tune(
|
def supervised_fine_tune(
|
||||||
self,
|
self,
|
||||||
request: PostTrainingSFTRequest,
|
job_uuid: str,
|
||||||
|
model: str,
|
||||||
|
dataset: TrainEvalDataset,
|
||||||
|
validation_dataset: TrainEvalDataset,
|
||||||
|
algorithm: FinetuningAlgorithm,
|
||||||
|
algorithm_config: Union[
|
||||||
|
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
|
||||||
|
],
|
||||||
|
optimizer_config: OptimizerConfig,
|
||||||
|
training_config: TrainingConfig,
|
||||||
|
hyperparam_search_config: Dict[str, Any],
|
||||||
|
logger_config: Dict[str, Any],
|
||||||
) -> PostTrainingJob: ...
|
) -> PostTrainingJob: ...
|
||||||
|
|
||||||
@webmethod(route="/post_training/preference_optimize")
|
@webmethod(route="/post_training/preference_optimize")
|
||||||
def preference_optimize(
|
def preference_optimize(
|
||||||
self,
|
self,
|
||||||
request: PostTrainingRLHFRequest,
|
job_uuid: str,
|
||||||
|
finetuned_model: URL,
|
||||||
|
dataset: TrainEvalDataset,
|
||||||
|
validation_dataset: TrainEvalDataset,
|
||||||
|
algorithm: RLHFAlgorithm,
|
||||||
|
algorithm_config: Union[DPOAlignmentConfig],
|
||||||
|
optimizer_config: OptimizerConfig,
|
||||||
|
training_config: TrainingConfig,
|
||||||
|
hyperparam_search_config: Dict[str, Any],
|
||||||
|
logger_config: Dict[str, Any],
|
||||||
) -> PostTrainingJob: ...
|
) -> PostTrainingJob: ...
|
||||||
|
|
||||||
@webmethod(route="/post_training/jobs")
|
@webmethod(route="/post_training/jobs")
|
||||||
|
|
|
@ -50,5 +50,6 @@ class RewardScoring(Protocol):
|
||||||
@webmethod(route="/reward_scoring/score")
|
@webmethod(route="/reward_scoring/score")
|
||||||
def reward_score(
|
def reward_score(
|
||||||
self,
|
self,
|
||||||
request: RewardScoringRequest,
|
dialog_generations: List[DialogGenerations],
|
||||||
|
model: str,
|
||||||
) -> Union[RewardScoringResponse]: ...
|
) -> Union[RewardScoringResponse]: ...
|
||||||
|
|
|
@ -86,5 +86,6 @@ class Safety(Protocol):
|
||||||
@webmethod(route="/safety/run_shields")
|
@webmethod(route="/safety/run_shields")
|
||||||
async def run_shields(
|
async def run_shields(
|
||||||
self,
|
self,
|
||||||
request: RunShieldRequest,
|
messages: List[Message],
|
||||||
|
shields: List[ShieldDefinition],
|
||||||
) -> RunShieldResponse: ...
|
) -> RunShieldResponse: ...
|
||||||
|
|
|
@ -13,10 +13,10 @@ import fire
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import UserMessage
|
from llama_models.llama3.api.datatypes import UserMessage
|
||||||
from pydantic import BaseModel
|
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_toolchain.core.datatypes import RemoteProviderConfig
|
from llama_toolchain.core.datatypes import RemoteProviderConfig
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from termcolor import cprint
|
||||||
|
|
||||||
from .api import * # noqa: F403
|
from .api import * # noqa: F403
|
||||||
|
|
||||||
|
@ -43,9 +43,7 @@ class SafetyClient(Safety):
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{self.base_url}/safety/run_shields",
|
f"{self.base_url}/safety/run_shields",
|
||||||
json={
|
json=encodable_dict(request),
|
||||||
"request": encodable_dict(request),
|
|
||||||
},
|
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
timeout=20,
|
timeout=20,
|
||||||
)
|
)
|
||||||
|
|
|
@ -52,13 +52,12 @@ class MetaReferenceSafetyImpl(Safety):
|
||||||
|
|
||||||
async def run_shields(
|
async def run_shields(
|
||||||
self,
|
self,
|
||||||
request: RunShieldRequest,
|
messages: List[Message],
|
||||||
|
shields: List[ShieldDefinition],
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
shields = [shield_config_to_shield(c, self.config) for c in request.shields]
|
shields = [shield_config_to_shield(c, self.config) for c in shields]
|
||||||
|
|
||||||
responses = await asyncio.gather(
|
responses = await asyncio.gather(*[shield.run(messages) for shield in shields])
|
||||||
*[shield.run(request.messages) for shield in shields]
|
|
||||||
)
|
|
||||||
|
|
||||||
return RunShieldResponse(responses=responses)
|
return RunShieldResponse(responses=responses)
|
||||||
|
|
||||||
|
|
|
@ -48,5 +48,7 @@ class SyntheticDataGeneration(Protocol):
|
||||||
@webmethod(route="/synthetic_data_generation/generate")
|
@webmethod(route="/synthetic_data_generation/generate")
|
||||||
def synthetic_data_generate(
|
def synthetic_data_generate(
|
||||||
self,
|
self,
|
||||||
request: SyntheticDataGenerationRequest,
|
dialogs: List[Message],
|
||||||
|
filtering_function: FilteringFunction = FilteringFunction.none,
|
||||||
|
model: Optional[str] = None,
|
||||||
) -> Union[SyntheticDataGenerationResponse]: ...
|
) -> Union[SyntheticDataGenerationResponse]: ...
|
||||||
|
|
|
@ -125,7 +125,7 @@ Event = Annotated[
|
||||||
|
|
||||||
class Telemetry(Protocol):
|
class Telemetry(Protocol):
|
||||||
@webmethod(route="/telemetry/log_event")
|
@webmethod(route="/telemetry/log_event")
|
||||||
async def log_event(self, event: Event): ...
|
async def log_event(self, event: Event) -> None: ...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/get_trace", method="GET")
|
@webmethod(route="/telemetry/get_trace", method="GET")
|
||||||
async def get_trace(self, trace_id: str) -> Trace: ...
|
async def get_trace(self, trace_id: str) -> Trace: ...
|
||||||
|
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -471,14 +471,9 @@ class Generator:
|
||||||
|
|
||||||
from dataclasses import make_dataclass
|
from dataclasses import make_dataclass
|
||||||
|
|
||||||
if len(op.request_params) == 1 and "Request" in first[1].__name__:
|
op_name = "".join(word.capitalize() for word in op.name.split("_"))
|
||||||
# TODO(ashwin): Undo the "Request" hack and this entire block eventually
|
request_name = f"{op_name}Request"
|
||||||
request_name = first[1].__name__ + "Wrapper"
|
request_type = make_dataclass(request_name, op.request_params)
|
||||||
request_type = make_dataclass(request_name, op.request_params)
|
|
||||||
else:
|
|
||||||
op_name = "".join(word.capitalize() for word in op.name.split("_"))
|
|
||||||
request_name = f"{op_name}Request"
|
|
||||||
request_type = make_dataclass(request_name, op.request_params)
|
|
||||||
|
|
||||||
requestBody = RequestBody(
|
requestBody = RequestBody(
|
||||||
content={
|
content={
|
||||||
|
|
|
@ -249,7 +249,12 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
|
||||||
stream=True,
|
stream=True,
|
||||||
tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
|
tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
|
||||||
)
|
)
|
||||||
iterator = self.api.chat_completion(request)
|
iterator = self.api.chat_completion(
|
||||||
|
request.model,
|
||||||
|
request.messages,
|
||||||
|
stream=request.stream,
|
||||||
|
tools=request.tools,
|
||||||
|
)
|
||||||
|
|
||||||
events = []
|
events = []
|
||||||
async for chunk in iterator:
|
async for chunk in iterator:
|
||||||
|
|
|
@ -61,7 +61,9 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
|
||||||
],
|
],
|
||||||
stream=False,
|
stream=False,
|
||||||
)
|
)
|
||||||
iterator = self.api.chat_completion(request)
|
iterator = self.api.chat_completion(
|
||||||
|
request.model, request.messages, stream=request.stream
|
||||||
|
)
|
||||||
async for r in iterator:
|
async for r in iterator:
|
||||||
response = r
|
response = r
|
||||||
print(response.completion_message.content)
|
print(response.completion_message.content)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue