llama-stack-mirror/llama_stack/distribution/routers/routers.py
Xi Yan 6192bf43a4
[Evals API][10/n] API updates for EvalTaskDef + new test migration (#379)
* wip

* scoring fn api

* eval api

* eval task

* evaluate api update

* pre-commit

* unwrap context -> config

* config field doc

* typo

* naming fix

* separate benchmark / app eval

* api name

* rename

* wip tests

* wip

* datasetio test

* delete unused

* fixture

* scoring resolve

* fix scoring register

* scoring test pass

* score batch

* scoring fix

* fix eval

* test eval works

* remove type ignore

* api refactor

* add default task_eval_id for routing

* add eval_id for jobs

* remove type ignore

* only keep 1 run_eval

* fix optional

* register task required

* register task required

* delete old tests

* delete old tests

* fixture return impl
2024-11-07 21:24:12 -08:00

319 lines
9.2 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List
from llama_stack.apis.datasetio.datasetio import DatasetIO
from llama_stack.distribution.datatypes import RoutingTable
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
class MemoryRouter(Memory):
    """Route memory-bank operations to the provider that owns the bank.

    The owning provider is looked up on every call by handing the bank
    identifier to ``routing_table.get_provider_impl``.
    """

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
        # Registration is handled by the routing table, not by a provider.
        await self.routing_table.register_memory_bank(memory_bank)

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        """Forward a document insertion to the provider serving ``bank_id``."""
        bank_provider = self.routing_table.get_provider_impl(bank_id)
        return await bank_provider.insert_documents(bank_id, documents, ttl_seconds)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        """Forward a document query to the provider serving ``bank_id``."""
        bank_provider = self.routing_table.get_provider_impl(bank_id)
        return await bank_provider.query_documents(bank_id, query, params)
class InferenceRouter(Inference):
    """Route inference calls to the provider registered for a model.

    The ``model`` argument of each call is the key handed to
    ``routing_table.get_provider_impl``.
    """

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def register_model(self, model: ModelDef) -> None:
        await self.routing_table.register_model(model)

    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Run a chat completion on the provider that serves ``model``.

        When ``stream`` is truthy, the provider's awaited result is re-wrapped
        in an async generator expression. Otherwise it is returned unchanged.
        """
        forwarded = {
            "model": model,
            "messages": messages,
            "sampling_params": sampling_params,
            # A falsy ``tools`` value (None or empty) is sent as a fresh empty list.
            "tools": tools or [],
            "tool_choice": tool_choice,
            "tool_prompt_format": tool_prompt_format,
            "response_format": response_format,
            "stream": stream,
            "logprobs": logprobs,
        }
        backend = self.routing_table.get_provider_impl(model)
        if not stream:
            return await backend.chat_completion(**forwarded)
        # The provider call is awaited here, before the wrapper is built.
        upstream = await backend.chat_completion(**forwarded)
        return (chunk async for chunk in upstream)

    async def completion(
        self,
        model: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Run a raw-text completion on the provider that serves ``model``.

        Streaming is handled the same way as in ``chat_completion``.
        """
        backend = self.routing_table.get_provider_impl(model)
        forwarded = {
            "model": model,
            "content": content,
            "sampling_params": sampling_params,
            "response_format": response_format,
            "stream": stream,
            "logprobs": logprobs,
        }
        if not stream:
            return await backend.completion(**forwarded)
        upstream = await backend.completion(**forwarded)
        return (chunk async for chunk in upstream)

    async def embeddings(
        self,
        model: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        """Compute embeddings for ``contents`` on the provider serving ``model``."""
        backend = self.routing_table.get_provider_impl(model)
        return await backend.embeddings(model=model, contents=contents)
class SafetyRouter(Safety):
    """Route shield operations to the provider registered for each shield."""

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def register_shield(self, shield: ShieldDef) -> None:
        # Registration is handled by the routing table, not by a provider.
        await self.routing_table.register_shield(shield)

    async def run_shield(
        self,
        identifier: str,
        messages: List[Message],
        # NOTE(review): defaults to None but is not annotated Optional; kept
        # unchanged here to avoid altering the public signature.
        params: Dict[str, Any] = None,
    ) -> RunShieldResponse:
        """Run shield ``identifier`` over ``messages`` on its owning provider."""
        shield_provider = self.routing_table.get_provider_impl(identifier)
        return await shield_provider.run_shield(
            identifier=identifier,
            messages=messages,
            params=params,
        )
class DatasetIORouter(DatasetIO):
    """Route dataset reads to the provider registered for each dataset."""

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def get_rows_paginated(
        self,
        dataset_id: str,
        rows_in_page: int,
        page_token: Optional[str] = None,
        filter_condition: Optional[str] = None,
    ) -> PaginatedRowsResult:
        """Fetch one page of rows of ``dataset_id`` from its owning provider.

        All paging and filtering arguments are forwarded by keyword, unchanged.
        """
        dataset_provider = self.routing_table.get_provider_impl(dataset_id)
        page = await dataset_provider.get_rows_paginated(
            dataset_id=dataset_id,
            rows_in_page=rows_in_page,
            page_token=page_token,
            filter_condition=filter_condition,
        )
        return page
class ScoringRouter(Scoring):
    """Fan scoring requests out to the provider that owns each scoring function.

    Each scoring function is routed independently through
    ``routing_table.get_provider_impl(<function id>)``. The per-function
    ``results`` mappings are then merged into a single response.
    """

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
        """Score the rows of ``dataset_id`` with every requested scoring function.

        ``scoring_functions`` maps scoring-function ids to optional params.
        Omitting it (None) means "no functions" and yields empty results.

        Raises NotImplementedError if ``save_results_dataset`` is true.
        """
        # Fail before any provider work: the results could not be saved
        # anyway, so scoring first would only waste the provider calls.
        if save_results_dataset:
            raise NotImplementedError("Save results dataset not implemented yet")

        res = {}
        # A None default previously crashed on `.keys()`; treat it as empty.
        for fn_identifier, fn_params in (scoring_functions or {}).items():
            provider = self.routing_table.get_provider_impl(fn_identifier)
            score_response = await provider.score_batch(
                dataset_id=dataset_id,
                scoring_functions={fn_identifier: fn_params},
            )
            # Later functions overwrite earlier keys if results collide.
            res.update(score_response.results)
        return ScoreBatchResponse(
            results=res,
        )

    async def score(
        self,
        input_rows: List[Dict[str, Any]],
        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
    ) -> ScoreResponse:
        """Score the given ``input_rows`` with every requested scoring function.

        ``scoring_functions`` maps scoring-function ids to optional params.
        Omitting it (None) means "no functions" and yields empty results.
        """
        res = {}
        # Look up and dispatch each scoring function to its own provider impl.
        for fn_identifier, fn_params in (scoring_functions or {}).items():
            provider = self.routing_table.get_provider_impl(fn_identifier)
            score_response = await provider.score(
                input_rows=input_rows,
                scoring_functions={fn_identifier: fn_params},
            )
            res.update(score_response.results)
        return ScoreResponse(results=res)
class EvalRouter(Eval):
    """Route evaluation calls to the provider registered for an eval task.

    Every method resolves its provider with
    ``routing_table.get_provider_impl(task_id)``. Route metadata is not
    declared here: like the other routers in this module, this class only
    forwards calls.
    """

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def run_eval(
        self,
        task_id: str,
        task_config: AppEvalTaskConfig,
    ) -> Job:
        """Start an evaluation of ``task_id`` on its provider and return the job."""
        return await self.routing_table.get_provider_impl(task_id).run_eval(
            task_id=task_id,
            task_config=task_config,
        )

    # The stray copy-pasted `@webmethod(...)` decorator was removed: no other
    # router method carries route metadata, and a duplicate can drift.
    async def evaluate_rows(
        self,
        task_id: str,
        input_rows: List[Dict[str, Any]],
        scoring_functions: List[str],
        task_config: EvalTaskConfig,
    ) -> EvaluateResponse:
        """Evaluate ``input_rows`` for ``task_id`` using the given scoring functions."""
        return await self.routing_table.get_provider_impl(task_id).evaluate_rows(
            task_id=task_id,
            input_rows=input_rows,
            scoring_functions=scoring_functions,
            task_config=task_config,
        )

    async def job_status(
        self,
        task_id: str,
        job_id: str,
    ) -> Optional[JobStatus]:
        """Return the provider-reported status of ``job_id`` for ``task_id``."""
        return await self.routing_table.get_provider_impl(task_id).job_status(
            task_id, job_id
        )

    async def job_cancel(
        self,
        task_id: str,
        job_id: str,
    ) -> None:
        """Ask the provider owning ``task_id`` to cancel ``job_id``."""
        await self.routing_table.get_provider_impl(task_id).job_cancel(
            task_id,
            job_id,
        )

    async def job_result(
        self,
        task_id: str,
        job_id: str,
    ) -> EvaluateResponse:
        """Return the provider-reported result of ``job_id`` for ``task_id``."""
        return await self.routing_table.get_provider_impl(task_id).job_result(
            task_id,
            job_id,
        )