Add a RoutableProvider protocol, support for multiple routing keys (#163)

* Update configure.py to use multiple routing keys for safety
* Refactor distribution/datatypes into a providers/datatypes
* Cleanup
This commit is contained in:
Ashwin Bharambe 2024-09-30 17:30:21 -07:00 committed by GitHub
parent 73decb3781
commit eb2d8a31a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 600 additions and 577 deletions

View file

@ -12,7 +12,8 @@ from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
@ -21,6 +22,7 @@ from llama_stack.providers.utils.inference.augment_messages import (
from .config import FireworksImplConfig
FIREWORKS_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
@ -28,8 +30,11 @@ FIREWORKS_SUPPORTED_MODELS = {
}
class FireworksInferenceAdapter(Inference):
class FireworksInferenceAdapter(Inference, RoutableProviderForModels):
def __init__(self, config: FireworksImplConfig) -> None:
RoutableProviderForModels.__init__(
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
)
self.config = config
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@ -65,18 +70,6 @@ class FireworksInferenceAdapter(Inference):
return fireworks_messages
def resolve_fireworks_model(self, model_name: str) -> str:
model = resolve_model(model_name)
assert (
model is not None
and model.descriptor(shorten_default_variant=True)
in FIREWORKS_SUPPORTED_MODELS
), f"Unsupported model: {model_name}, use one of the supported models: {','.join(FIREWORKS_SUPPORTED_MODELS.keys())}"
return FIREWORKS_SUPPORTED_MODELS.get(
model.descriptor(shorten_default_variant=True)
)
def get_fireworks_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
@ -112,7 +105,7 @@ class FireworksInferenceAdapter(Inference):
# accumulate sampling params and other options to pass to fireworks
options = self.get_fireworks_chat_options(request)
fireworks_model = self.resolve_fireworks_model(request.model)
fireworks_model = self.map_to_provider_model(request.model)
if not request.stream:
r = await self.client.chat.completions.acreate(