diff --git a/llama_stack/distribution/routers/datasets.py b/llama_stack/distribution/routers/datasets.py
index 6f28756c9..d7984f729 100644
--- a/llama_stack/distribution/routers/datasets.py
+++ b/llama_stack/distribution/routers/datasets.py
@@ -57,7 +57,8 @@ class DatasetIORouter(DatasetIO):
         logger.debug(
             f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
         )
-        return await self.routing_table.get_provider_impl(dataset_id).iterrows(
+        provider = await self.routing_table.get_provider_impl(dataset_id)
+        return await provider.iterrows(
             dataset_id=dataset_id,
             start_index=start_index,
             limit=limit,
@@ -65,7 +66,8 @@ class DatasetIORouter(DatasetIO):
 
     async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
         logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
-        return await self.routing_table.get_provider_impl(dataset_id).append_rows(
+        provider = await self.routing_table.get_provider_impl(dataset_id)
+        return await provider.append_rows(
             dataset_id=dataset_id,
             rows=rows,
         )
diff --git a/llama_stack/distribution/routers/eval_scoring.py b/llama_stack/distribution/routers/eval_scoring.py
index fd0bb90a7..f7a17eecf 100644
--- a/llama_stack/distribution/routers/eval_scoring.py
+++ b/llama_stack/distribution/routers/eval_scoring.py
@@ -44,7 +44,8 @@ class ScoringRouter(Scoring):
         logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
+            provider = await self.routing_table.get_provider_impl(fn_identifier)
+            score_response = await provider.score_batch(
                 dataset_id=dataset_id,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -66,7 +67,8 @@ class ScoringRouter(Scoring):
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
+            provider = await self.routing_table.get_provider_impl(fn_identifier)
+            score_response = await provider.score(
                 input_rows=input_rows,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -97,7 +99,8 @@ class EvalRouter(Eval):
         benchmark_config: BenchmarkConfig,
     ) -> Job:
         logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        return await provider.run_eval(
             benchmark_id=benchmark_id,
             benchmark_config=benchmark_config,
         )
@@ -110,7 +113,8 @@ class EvalRouter(Eval):
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
         logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
-        return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        return await provider.evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
             scoring_functions=scoring_functions,
@@ -123,7 +127,8 @@ class EvalRouter(Eval):
         job_id: str,
     ) -> Job:
         logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        return await provider.job_status(benchmark_id, job_id)
 
     async def job_cancel(
         self,
@@ -131,7 +136,8 @@ class EvalRouter(Eval):
         job_id: str,
     ) -> None:
         logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
-        await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        await provider.job_cancel(
             benchmark_id,
             job_id,
         )
@@ -142,7 +148,8 @@ class EvalRouter(Eval):
         job_id: str,
     ) -> EvaluateResponse:
         logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_result(
+        provider = await self.routing_table.get_provider_impl(benchmark_id)
+        return await provider.job_result(
             benchmark_id,
             job_id,
         )
diff --git a/llama_stack/distribution/routers/inference.py b/llama_stack/distribution/routers/inference.py
index b39da7810..a5cc8c4b5 100644
--- a/llama_stack/distribution/routers/inference.py
+++ b/llama_stack/distribution/routers/inference.py
@@ -231,7 +231,7 @@ class InferenceRouter(Inference):
             logprobs=logprobs,
             tool_config=tool_config,
         )
-        provider = self.routing_table.get_provider_impl(model_id)
+        provider = await self.routing_table.get_provider_impl(model_id)
         prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
 
         if stream:
@@ -292,7 +292,7 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
         )
-        provider = self.routing_table.get_provider_impl(model_id)
+        provider = await self.routing_table.get_provider_impl(model_id)
         return await provider.batch_chat_completion(
             model_id=model_id,
             messages_batch=messages_batch,
@@ -322,7 +322,7 @@ class InferenceRouter(Inference):
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
-        provider = self.routing_table.get_provider_impl(model_id)
+        provider = await self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
             content=content,
@@ -378,7 +378,7 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
         )
-        provider = self.routing_table.get_provider_impl(model_id)
+        provider = await self.routing_table.get_provider_impl(model_id)
         return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
 
     async def embeddings(
@@ -395,7 +395,8 @@ class InferenceRouter(Inference):
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.llm:
             raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
-        return await self.routing_table.get_provider_impl(model_id).embeddings(
+        provider = await self.routing_table.get_provider_impl(model_id)
+        return await provider.embeddings(
             model_id=model_id,
             contents=contents,
             text_truncation=text_truncation,
@@ -458,7 +459,7 @@ class InferenceRouter(Inference):
             suffix=suffix,
         )
 
-        provider = self.routing_table.get_provider_impl(model_obj.identifier)
+        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
         return await provider.openai_completion(**params)
 
     async def openai_chat_completion(
@@ -538,7 +539,7 @@ class InferenceRouter(Inference):
             user=user,
         )
 
-        provider = self.routing_table.get_provider_impl(model_obj.identifier)
+        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
         if stream:
             response_stream = await provider.openai_chat_completion(**params)
             if self.store:
@@ -575,7 +576,7 @@ class InferenceRouter(Inference):
             user=user,
         )
 
-        provider = self.routing_table.get_provider_impl(model_obj.identifier)
+        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
         return await provider.openai_embeddings(**params)
 
     async def list_chat_completions(
diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py
index 15325276f..fd810f8cc 100644
--- a/llama_stack/distribution/routing_tables/common.py
+++ b/llama_stack/distribution/routing_tables/common.py
@@ -117,7 +117,7 @@ class CommonRoutingTableImpl(RoutingTable):
         for p in self.impls_by_provider_id.values():
             await p.shutdown()
 
-    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
+    async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
         from .benchmarks import BenchmarksRoutingTable
         from .datasets import DatasetsRoutingTable
         from .models import ModelsRoutingTable
diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/distribution/routing_tables/toolgroups.py
index b86f057bd..5df38ab64 100644
--- a/llama_stack/distribution/routing_tables/toolgroups.py
+++ b/llama_stack/distribution/routing_tables/toolgroups.py
@@ -30,7 +30,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
     tool_to_toolgroup: dict[str, str] = {}
 
     # overridden
-    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
+    async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
         # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
         # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
 
@@ -40,7 +40,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         if routing_key in self.tool_to_toolgroup:
             routing_key = self.tool_to_toolgroup[routing_key]
 
-        return super().get_provider_impl(routing_key, provider_id)
+        return await super().get_provider_impl(routing_key, provider_id)
 
     async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
         if toolgroup_id:
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index efe8a98fe..424380324 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -113,7 +113,7 @@ class ProviderSpec(BaseModel):
 
 
 class RoutingTable(Protocol):
-    def get_provider_impl(self, routing_key: str) -> Any: ...
+    async def get_provider_impl(self, routing_key: str) -> Any: ...
 
 
 # TODO: this can now be inlined into RemoteProviderSpec
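For context, a minimal standalone sketch of the call-site pattern this patch applies (the Toy* classes and the "my-dataset" key are hypothetical illustrations, not llama_stack code): `get_provider_impl` becomes a coroutine, so routers first await the lookup and then await the provider method, instead of chaining both calls in one expression.

```python
# Sketch only: toy stand-ins for the routing table / router, assuming the
# pattern in the diff above (async provider lookup, then async provider call).
import asyncio
from typing import Any, Protocol


class RoutingTable(Protocol):
    async def get_provider_impl(self, routing_key: str) -> Any: ...


class ToyProvider:
    async def iterrows(self, dataset_id: str) -> list[dict[str, Any]]:
        # stand-in for a real DatasetIO provider implementation
        return [{"dataset_id": dataset_id, "row": 0}]


class ToyRoutingTable:
    def __init__(self) -> None:
        self._impls = {"my-dataset": ToyProvider()}

    async def get_provider_impl(self, routing_key: str) -> Any:
        # async so an implementation is free to consult a registry/kvstore here
        return self._impls[routing_key]


class ToyDatasetIORouter:
    def __init__(self, routing_table: RoutingTable) -> None:
        self.routing_table = routing_table

    async def iterrows(self, dataset_id: str) -> list[dict[str, Any]]:
        # two awaits now: one for the lookup, one for the provider call
        provider = await self.routing_table.get_provider_impl(dataset_id)
        return await provider.iterrows(dataset_id=dataset_id)


if __name__ == "__main__":
    router = ToyDatasetIORouter(ToyRoutingTable())
    print(asyncio.run(router.iterrows("my-dataset")))
```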