mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 20:40:00 +00:00

kill older tech-debt, make get_provider_impl async

This commit is contained in:
parent 38a9c119df
commit a66074a10e

6 changed files with 31 additions and 21 deletions
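Every hunk below applies the same mechanical change: RoutingTable.get_provider_impl becomes a coroutine, so each call site either gains an await on the lookup or splits a chained get_provider_impl(key).method(...) call into an awaited lookup followed by the provider call. A minimal runnable sketch of the pattern, using toy names (ToyProvider, ToyRoutingTable, "my-dataset") that are illustrative only, not llama-stack code:

# Illustrative sketch only: ToyProvider and ToyRoutingTable are stand-ins,
# not llama-stack classes.
import asyncio
from typing import Any


class ToyProvider:
    async def iterrows(self, dataset_id: str) -> list[dict[str, Any]]:
        return [{"dataset": dataset_id, "row": 0}]


class ToyRoutingTable:
    def __init__(self) -> None:
        self._impls: dict[str, Any] = {"my-dataset": ToyProvider()}

    # Now a coroutine: the lookup may do async work (e.g. consult a
    # registry) before returning the provider implementation.
    async def get_provider_impl(self, routing_key: str) -> Any:
        return self._impls[routing_key]


async def main() -> None:
    table = ToyRoutingTable()
    # Before: sync lookup, chained directly into the provider call:
    #     return await table.get_provider_impl("my-dataset").iterrows(...)
    # After: await the lookup first, then call the provider method.
    provider = await table.get_provider_impl("my-dataset")
    rows = await provider.iterrows(dataset_id="my-dataset")
    print(rows)


asyncio.run(main())

Splitting the chain also gives the intermediate provider a name, which makes the awaited lookup explicit at each call site.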
@@ -57,7 +57,8 @@ class DatasetIORouter(DatasetIO):
         logger.debug(
             f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
         )
-        return await self.routing_table.get_provider_impl(dataset_id).iterrows(
+        provider = await self.routing_table.get_provider_impl(dataset_id)
+        return await provider.iterrows(
             dataset_id=dataset_id,
             start_index=start_index,
             limit=limit,
@@ -65,7 +66,8 @@ class DatasetIORouter(DatasetIO):

     async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
         logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
-        return await self.routing_table.get_provider_impl(dataset_id).append_rows(
+        provider = await self.routing_table.get_provider_impl(dataset_id)
+        return await provider.append_rows(
             dataset_id=dataset_id,
             rows=rows,
         )
@@ -44,7 +44,8 @@ class ScoringRouter(Scoring):
         logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
+            provider = await self.routing_table.get_provider_impl(fn_identifier)
+            score_response = await provider.score_batch(
                 dataset_id=dataset_id,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -66,7 +67,8 @@ class ScoringRouter(Scoring):
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
+            provider = await self.routing_table.get_provider_impl(fn_identifier)
+            score_response = await provider.score(
                 input_rows=input_rows,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -97,7 +99,8 @@ class EvalRouter(Eval):
         benchmark_config: BenchmarkConfig,
     ) -> Job:
         logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        return await provider.run_eval(
             benchmark_id=benchmark_id,
             benchmark_config=benchmark_config,
         )
@@ -110,7 +113,8 @@ class EvalRouter(Eval):
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
         logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
-        return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        return await provider.evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
             scoring_functions=scoring_functions,
@@ -123,7 +127,8 @@ class EvalRouter(Eval):
         job_id: str,
     ) -> Job:
         logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        return await provider.job_status(benchmark_id, job_id)

     async def job_cancel(
         self,
@@ -131,7 +136,8 @@ class EvalRouter(Eval):
         job_id: str,
     ) -> None:
         logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
-        await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        await provider.job_cancel(
             benchmark_id,
             job_id,
         )
@@ -142,7 +148,8 @@ class EvalRouter(Eval):
         job_id: str,
     ) -> EvaluateResponse:
         logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_result(
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        return await provider.job_result(
             benchmark_id,
             job_id,
         )
@@ -231,7 +231,7 @@ class InferenceRouter(Inference):
             logprobs=logprobs,
             tool_config=tool_config,
         )
-        provider = self.routing_table.get_provider_impl(model_id)
+        provider = await self.routing_table.get_provider_impl(model_id)
         prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)

         if stream:
@@ -292,7 +292,7 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
         )
-        provider = self.routing_table.get_provider_impl(model_id)
+        provider = await self.routing_table.get_provider_impl(model_id)
         return await provider.batch_chat_completion(
             model_id=model_id,
             messages_batch=messages_batch,
@@ -322,7 +322,7 @@ class InferenceRouter(Inference):
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
-        provider = self.routing_table.get_provider_impl(model_id)
+        provider = await self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
             content=content,
@@ -378,7 +378,7 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
         )
-        provider = self.routing_table.get_provider_impl(model_id)
+        provider = await self.routing_table.get_provider_impl(model_id)
         return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)

     async def embeddings(
@@ -395,7 +395,8 @@ class InferenceRouter(Inference):
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.llm:
             raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
-        return await self.routing_table.get_provider_impl(model_id).embeddings(
+        provider = await self.routing_table.get_provider_impl(model_id)
+        return await provider.embeddings(
             model_id=model_id,
             contents=contents,
             text_truncation=text_truncation,
@@ -458,7 +459,7 @@ class InferenceRouter(Inference):
             suffix=suffix,
         )

-        provider = self.routing_table.get_provider_impl(model_obj.identifier)
+        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
         return await provider.openai_completion(**params)

     async def openai_chat_completion(
@@ -538,7 +539,7 @@ class InferenceRouter(Inference):
             user=user,
         )

-        provider = self.routing_table.get_provider_impl(model_obj.identifier)
+        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
         if stream:
             response_stream = await provider.openai_chat_completion(**params)
             if self.store:
@@ -575,7 +576,7 @@ class InferenceRouter(Inference):
             user=user,
         )

-        provider = self.routing_table.get_provider_impl(model_obj.identifier)
+        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
         return await provider.openai_embeddings(**params)

     async def list_chat_completions(
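The InferenceRouter hunks above are pure one-line await additions. Worth noting for reviewers: Python will not catch a missed await on the now-async lookup at import time; the symptom is a runtime failure on the coroutine object, plus a "never awaited" RuntimeWarning. A toy pitfall demo, not repository code:

# Toy pitfall demo, not repository code: get_provider_impl here is a
# stand-in coroutine, not the llama-stack routing table.
import asyncio


async def get_provider_impl(routing_key: str) -> str:
    return f"provider-for-{routing_key}"


async def main() -> None:
    bad = get_provider_impl("model-x")  # missing await!
    print(type(bad))  # <class 'coroutine'>; any provider method access fails
    bad.close()  # silence the "coroutine was never awaited" RuntimeWarning

    good = await get_provider_impl("model-x")
    print(good)  # provider-for-model-x


asyncio.run(main())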
@@ -117,7 +117,7 @@ class CommonRoutingTableImpl(RoutingTable):
         for p in self.impls_by_provider_id.values():
             await p.shutdown()

-    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
+    async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
         from .benchmarks import BenchmarksRoutingTable
         from .datasets import DatasetsRoutingTable
         from .models import ModelsRoutingTable
@@ -30,7 +30,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
     tool_to_toolgroup: dict[str, str] = {}

     # overridden
-    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
+    async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
         # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
         # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?

@@ -40,7 +40,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):

         if routing_key in self.tool_to_toolgroup:
             routing_key = self.tool_to_toolgroup[routing_key]
-        return super().get_provider_impl(routing_key, provider_id)
+        return await super().get_provider_impl(routing_key, provider_id)

     async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
         if toolgroup_id:
@@ -113,7 +113,7 @@ class ProviderSpec(BaseModel):


 class RoutingTable(Protocol):
-    def get_provider_impl(self, routing_key: str) -> Any: ...
+    async def get_provider_impl(self, routing_key: str) -> Any: ...


 # TODO: this can now be inlined into RemoteProviderSpec
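Declaring the Protocol member with async def means structural type checkers (mypy, pyright) now require every routing table implementation to provide a coroutine method as well. A small sketch assuming only what the diff shows; InMemoryRoutingTable is a hypothetical implementation for illustration:

# Sketch under assumptions: InMemoryRoutingTable is hypothetical, shown
# only to illustrate the Protocol contract from the diff above.
from typing import Any, Protocol


class RoutingTable(Protocol):
    async def get_provider_impl(self, routing_key: str) -> Any: ...


class InMemoryRoutingTable:
    def __init__(self, impls: dict[str, Any]) -> None:
        self._impls = impls

    # Must be `async def` to satisfy the Protocol; a plain `def` here would
    # be flagged by mypy/pyright as an incompatible member.
    async def get_provider_impl(self, routing_key: str) -> Any:
        return self._impls[routing_key]


def accepts_table(table: RoutingTable) -> None:
    """Type-checks because InMemoryRoutingTable structurally matches."""


accepts_table(InMemoryRoutingTable({"key": object()}))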