router method wrapper

Xi Yan 2024-09-21 15:56:20 -07:00
parent 951cc9d7b7
commit 04f480d70c
2 changed files with 43 additions and 12 deletions


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, AsyncGenerator, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple

 from llama_stack.distribution.datatypes import Api
@@ -12,6 +12,10 @@ from .routing_table import RoutingTable
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
+
+from types import MethodType
+
+from termcolor import cprint

 class MemoryRouter(Memory):
     """Routes to an provider based on the memory bank type"""
@@ -84,9 +88,10 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        provider_impl = self.routing_table.get_provider_impl(self.api, model)
-        print("InferenceRouter: chat_completion", provider_impl)
-        async for chunk in provider_impl.chat_completion(
+        # TODO: we need to fix streaming response to align provider implementations with Protocol
+        async for chunk in self.routing_table.get_provider_impl(
+            self.api, model
+        ).chat_completion(
             model=model,
             messages=messages,
             sampling_params=sampling_params,
@@ -97,3 +102,29 @@ class InferenceRouter(Inference):
             logprobs=logprobs,
         ):
             yield chunk
+
+    async def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+        return await self.routing_table.get_provider_impl(self.api, model).completion(
+            model=model,
+            content=content,
+            sampling_params=sampling_params,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
+    async def embeddings(
+        self,
+        model: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse:
+        return await self.routing_table.get_provider_impl(self.api, model).embeddings(
+            model=model,
+            contents=contents,
+        )
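
Aside: the new completion and embeddings methods are hand-written forwarders, while the freshly imported types.MethodType (together with the commit title, "router method wrapper") suggests generating such forwarders dynamically. A minimal sketch of that pattern, assuming a routing table exposing get_provider_impl(api, routing_key); every name below is illustrative, not the actual llama_stack API:

import asyncio
from types import MethodType


class RoutingTable:
    """Toy routing table: maps (api, routing_key) -> provider implementation."""

    def __init__(self, table):
        self.table = table

    def get_provider_impl(self, api, routing_key):
        return self.table[(api, routing_key)]


def make_router_method(api, method_name):
    # Build an async forwarder that resolves the provider by the `model`
    # routing key and delegates the call, like completion/embeddings above.
    async def route(self, model, **kwargs):
        impl = self.routing_table.get_provider_impl(api, model)
        return await getattr(impl, method_name)(model=model, **kwargs)

    return route


class InferenceRouter:
    def __init__(self, routing_table):
        self.routing_table = routing_table
        # Bind one forwarder per API method instead of writing each by hand.
        for name in ("completion", "embeddings"):
            setattr(self, name, MethodType(make_router_method("inference", name), self))


class EchoProvider:  # stand-in provider for the usage example
    async def completion(self, model, **kwargs):
        return f"echo from {model}"


router = InferenceRouter(RoutingTable({("inference", "Meta-Llama3.1-8B"): EchoProvider()}))
print(asyncio.run(router.completion(model="Meta-Llama3.1-8B")))  # echo from Meta-Llama3.1-8B

A plain return-await forwarder like this only covers non-streaming calls; chat_completion yields chunks and needs an async-generator variant, which is what the TODO added in the hunk above alludes to.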


@@ -20,14 +20,14 @@ provider_routing_table:
         torch_seed: null
         max_seq_len: 4096
         max_batch_size: 1
-    - routing_key: Meta-Llama3.1-8B
-      provider_id: meta-reference
-      config:
-        model: Meta-Llama3.1-8B
-        quantization: null
-        torch_seed: null
-        max_seq_len: 4096
-        max_batch_size: 1
+    # - routing_key: Meta-Llama3.1-8B
+    #   provider_id: meta-reference
+    #   config:
+    #     model: Meta-Llama3.1-8B
+    #     quantization: null
+    #     torch_seed: null
+    #     max_seq_len: 4096
+    #     max_batch_size: 1
   # memory:
   #   - routing_key: keyvalue
   #     provider_id: remote::pgvector
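
For context on how this config is consumed: each list entry under provider_routing_table maps a routing_key (a model name, for inference) to a provider_id plus its config, and the routers above resolve calls through exactly that mapping. A toy sketch of the lookup, assuming PyYAML and a hypothetical flattening helper (not llama_stack code):

import yaml  # PyYAML

RUN_CONFIG = """
provider_routing_table:
  inference:
    - routing_key: Meta-Llama3.1-8B
      provider_id: meta-reference
      config:
        max_seq_len: 4096
        max_batch_size: 1
"""


def build_routing_table(raw):
    # Flatten the per-API entry lists into {(api, routing_key): entry}.
    cfg = yaml.safe_load(raw)["provider_routing_table"]
    return {
        (api, entry["routing_key"]): entry
        for api, entries in cfg.items()
        for entry in entries
    }


table = build_routing_table(RUN_CONFIG)
print(table[("inference", "Meta-Llama3.1-8B")]["provider_id"])  # meta-reference

Commenting an entry out, as this commit does, simply removes its (api, routing_key) pair from the table, so requests using that routing key no longer resolve to a provider.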