router method wrapper

This commit is contained in:
Xi Yan 2024-09-21 15:56:20 -07:00
parent 951cc9d7b7
commit 04f480d70c
2 changed files with 43 additions and 12 deletions

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Tuple
from typing import Any, Dict, List, Tuple
from llama_stack.distribution.datatypes import Api
@ -12,6 +12,10 @@ from .routing_table import RoutingTable
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from types import MethodType
from termcolor import cprint
class MemoryRouter(Memory):
"""Routes to an provider based on the memory bank type"""
@ -84,9 +88,10 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
provider_impl = self.routing_table.get_provider_impl(self.api, model)
print("InferenceRouter: chat_completion", provider_impl)
async for chunk in provider_impl.chat_completion(
# TODO: we need to fix streaming response to align provider implementations with Protocol
async for chunk in self.routing_table.get_provider_impl(
self.api, model
).chat_completion(
model=model,
messages=messages,
sampling_params=sampling_params,
@ -97,3 +102,29 @@ class InferenceRouter(Inference):
logprobs=logprobs,
):
yield chunk
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
return await self.routing_table.get_provider_impl(self.api, model).completion(
model=model,
content=content,
sampling_params=sampling_params,
stream=stream,
logprobs=logprobs,
)
async def embeddings(
self,
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
return await self.routing_table.get_provider_impl(self.api, model).embeddings(
model=model,
contents=contents,
)

View file

@ -20,14 +20,14 @@ provider_routing_table:
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
- routing_key: Meta-Llama3.1-8B
provider_id: meta-reference
config:
model: Meta-Llama3.1-8B
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
# - routing_key: Meta-Llama3.1-8B
# provider_id: meta-reference
# config:
# model: Meta-Llama3.1-8B
# quantization: null
# torch_seed: null
# max_seq_len: 4096
# max_batch_size: 1
# memory:
# - routing_key: keyvalue
# provider_id: remote::pgvector