kill older tech-debt, make get_provider_impl async

Ashwin Bharambe 2025-07-22 10:20:04 -07:00
parent 38a9c119df
commit a66074a10e
6 changed files with 31 additions and 21 deletions
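
Every hunk in this commit applies the same mechanical change: get_provider_impl becomes a coroutine, so the old chained call has to be split into an awaited lookup followed by the provider method call. A minimal, runnable sketch of the pattern, using stand-in classes rather than the actual llama-stack types:

import asyncio
from typing import Any


class StubProvider:
    async def iterrows(self, dataset_id: str) -> list[dict[str, Any]]:
        return [{"dataset_id": dataset_id}]


class StubRoutingTable:
    async def get_provider_impl(self, routing_key: str) -> Any:
        # Now a coroutine: the real implementation can await registry I/O here.
        return StubProvider()


class Router:
    def __init__(self) -> None:
        self.routing_table = StubRoutingTable()

    async def iterrows(self, dataset_id: str) -> list[dict[str, Any]]:
        # Before: return await self.routing_table.get_provider_impl(...).iterrows(...)
        # After: the lookup itself must be awaited before the method call.
        provider = await self.routing_table.get_provider_impl(dataset_id)
        return await provider.iterrows(dataset_id)


print(asyncio.run(Router().iterrows("demo")))

The intermediate variable is a style choice: (await table.get_provider_impl(key)).iterrows(...) is valid Python too, but the two-step form reads better and is what the hunks below use throughout.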

View file

@@ -57,7 +57,8 @@ class DatasetIORouter(DatasetIO):
         logger.debug(
             f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
         )
-        return await self.routing_table.get_provider_impl(dataset_id).iterrows(
+        provider = await self.routing_table.get_provider_impl(dataset_id)
+        return await provider.iterrows(
             dataset_id=dataset_id,
             start_index=start_index,
             limit=limit,

@@ -65,7 +66,8 @@ class DatasetIORouter(DatasetIO):
     async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
         logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
-        return await self.routing_table.get_provider_impl(dataset_id).append_rows(
+        provider = await self.routing_table.get_provider_impl(dataset_id)
+        return await provider.append_rows(
             dataset_id=dataset_id,
             rows=rows,
         )

View file

@@ -44,7 +44,8 @@ class ScoringRouter(Scoring):
         logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
+            provider = await self.routing_table.get_provider_impl(fn_identifier)
+            score_response = await provider.score_batch(
                 dataset_id=dataset_id,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )

@@ -66,7 +67,8 @@ class ScoringRouter(Scoring):
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
+            provider = await self.routing_table.get_provider_impl(fn_identifier)
+            score_response = await provider.score(
                 input_rows=input_rows,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )

@@ -97,7 +99,8 @@ class EvalRouter(Eval):
         benchmark_config: BenchmarkConfig,
     ) -> Job:
         logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        return await provider.run_eval(
             benchmark_id=benchmark_id,
             benchmark_config=benchmark_config,
         )

@@ -110,7 +113,8 @@ class EvalRouter(Eval):
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
         logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
-        return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        return await provider.evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
             scoring_functions=scoring_functions,

@@ -123,7 +127,8 @@ class EvalRouter(Eval):
         job_id: str,
     ) -> Job:
         logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        return await provider.job_status(benchmark_id, job_id)

     async def job_cancel(
         self,

@@ -131,7 +136,8 @@ class EvalRouter(Eval):
         job_id: str,
     ) -> None:
         logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
-        await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        await provider.job_cancel(
             benchmark_id,
             job_id,
         )

@@ -142,7 +148,8 @@ class EvalRouter(Eval):
         job_id: str,
     ) -> EvaluateResponse:
         logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_result(
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        return await provider.job_result(
             benchmark_id,
             job_id,
         )

View file

@@ -231,7 +231,7 @@ class InferenceRouter(Inference):
             logprobs=logprobs,
             tool_config=tool_config,
         )
-        provider = self.routing_table.get_provider_impl(model_id)
+        provider = await self.routing_table.get_provider_impl(model_id)
         prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)

         if stream:

@@ -292,7 +292,7 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
         )
-        provider = self.routing_table.get_provider_impl(model_id)
+        provider = await self.routing_table.get_provider_impl(model_id)
         return await provider.batch_chat_completion(
             model_id=model_id,
             messages_batch=messages_batch,

@@ -322,7 +322,7 @@ class InferenceRouter(Inference):
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
-        provider = self.routing_table.get_provider_impl(model_id)
+        provider = await self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
             content=content,

@@ -378,7 +378,7 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
         )
-        provider = self.routing_table.get_provider_impl(model_id)
+        provider = await self.routing_table.get_provider_impl(model_id)
         return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)

     async def embeddings(

@@ -395,7 +395,8 @@ class InferenceRouter(Inference):
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.llm:
             raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
-        return await self.routing_table.get_provider_impl(model_id).embeddings(
+        provider = await self.routing_table.get_provider_impl(model_id)
+        return await provider.embeddings(
             model_id=model_id,
             contents=contents,
             text_truncation=text_truncation,

@@ -458,7 +459,7 @@ class InferenceRouter(Inference):
             suffix=suffix,
         )
-        provider = self.routing_table.get_provider_impl(model_obj.identifier)
+        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
         return await provider.openai_completion(**params)

     async def openai_chat_completion(

@@ -538,7 +539,7 @@ class InferenceRouter(Inference):
             user=user,
         )
-        provider = self.routing_table.get_provider_impl(model_obj.identifier)
+        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
         if stream:
             response_stream = await provider.openai_chat_completion(**params)
             if self.store:

@@ -575,7 +576,7 @@ class InferenceRouter(Inference):
             user=user,
         )
-        provider = self.routing_table.get_provider_impl(model_obj.identifier)
+        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
         return await provider.openai_embeddings(**params)

     async def list_chat_completions(

View file

@@ -117,7 +117,7 @@ class CommonRoutingTableImpl(RoutingTable):
         for p in self.impls_by_provider_id.values():
             await p.shutdown()

-    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
+    async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
         from .benchmarks import BenchmarksRoutingTable
         from .datasets import DatasetsRoutingTable
         from .models import ModelsRoutingTable
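
With the signature change in place, CommonRoutingTableImpl is free to await I/O while resolving a routing key, something the synchronous signature ruled out. A hypothetical sketch of what an async resolver enables; the AsyncRegistry here is invented for illustration, and the real lookup logic in CommonRoutingTableImpl differs:

import asyncio
from typing import Any


class AsyncRegistry:
    """Invented stand-in for a registry whose reads are asynchronous."""

    def __init__(self, objects: dict[str, str]) -> None:
        self._objects = objects

    async def get(self, routing_key: str) -> str | None:
        await asyncio.sleep(0)  # simulate an awaited backend read
        return self._objects.get(routing_key)


class RoutingTable:
    def __init__(self, registry: AsyncRegistry, impls_by_provider_id: dict[str, Any]) -> None:
        self.registry = registry
        self.impls_by_provider_id = impls_by_provider_id

    async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        # Being a coroutine lets the lookup await the registry inline.
        resolved = provider_id or await self.registry.get(routing_key)
        if resolved is None or resolved not in self.impls_by_provider_id:
            raise ValueError(f"No provider registered for `{routing_key}`")
        return self.impls_by_provider_id[resolved]


async def main() -> None:
    table = RoutingTable(AsyncRegistry({"my-dataset": "localfs"}), {"localfs": object()})
    print(await table.get_provider_impl("my-dataset"))


asyncio.run(main())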

View file

@@ -30,7 +30,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
     tool_to_toolgroup: dict[str, str] = {}

     # overridden
-    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
+    async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
         # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
         # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?

@@ -40,7 +40,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         if routing_key in self.tool_to_toolgroup:
             routing_key = self.tool_to_toolgroup[routing_key]
-        return super().get_provider_impl(routing_key, provider_id)
+        return await super().get_provider_impl(routing_key, provider_id)

     async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
         if toolgroup_id:
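
A detail the override above relies on: once the base class method is a coroutine, the override must itself be declared async, and the super() call must be awaited, otherwise callers would get back a bare coroutine object instead of a provider. A minimal illustration with hypothetical names:

import asyncio
from typing import Any


class Base:
    async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        return f"impl-for-{routing_key}"


class ToolGroupsTable(Base):
    # overridden: remap a tool name to its owning group, then delegate
    async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        tool_to_toolgroup = {"web_search": "builtin::websearch"}
        routing_key = tool_to_toolgroup.get(routing_key, routing_key)
        # super() now returns a coroutine, so it must be awaited
        return await super().get_provider_impl(routing_key, provider_id)


print(asyncio.run(ToolGroupsTable().get_provider_impl("web_search")))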

View file

@@ -113,7 +113,7 @@ class ProviderSpec(BaseModel):
 class RoutingTable(Protocol):
-    def get_provider_impl(self, routing_key: str) -> Any: ...
+    async def get_provider_impl(self, routing_key: str) -> Any: ...


 # TODO: this can now be inlined into RemoteProviderSpec
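
Because RoutingTable is a typing.Protocol, moving the method to async def changes the structural contract itself: type checkers will only accept implementations whose get_provider_impl is a coroutine function, and every call site typed against the protocol must await it. An illustrative sketch of a conforming implementation and caller:

import asyncio
from typing import Any, Protocol


class RoutingTable(Protocol):
    async def get_provider_impl(self, routing_key: str) -> Any: ...


class InMemoryTable:
    """Conforms structurally because the method is a coroutine function."""

    async def get_provider_impl(self, routing_key: str) -> Any:
        return {"routing_key": routing_key}


async def resolve(table: RoutingTable, key: str) -> Any:
    # Callers typed against the protocol must now await the lookup.
    return await table.get_provider_impl(key)


print(asyncio.run(resolve(InMemoryTable(), "demo")))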