commit 8de4cee373 (parent 25d8ab0e14)
working fireworks and together

Mirror of https://github.com/meta-llama/llama-stack.git
8 changed files with 205 additions and 86 deletions
Changes to the inference router (InferenceRouter):

@@ -105,9 +105,8 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        model = await self.routing_table.get_model(model_id)
         params = dict(
-            model_id=model.provider_resource_id,
+            model_id=model_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -132,10 +131,9 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        model = await self.routing_table.get_model(model_id)
         provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
-            model_id=model.provider_resource_id,
+            model_id=model_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -152,9 +150,8 @@ class InferenceRouter(Inference):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        model = await self.routing_table.get_model(model_id)
         return await self.routing_table.get_provider_impl(model_id).embeddings(
-            model_id=model.provider_resource_id,
+            model_id=model_id,
             contents=contents,
         )

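The effect of these hunks: the router no longer resolves a model before dispatching; it forwards the caller's model_id unchanged and leaves resolution of the provider_resource_id to the provider. A minimal sketch of the new flow, using hypothetical stand-in classes rather than the real llama-stack ones:

# Minimal sketch of the post-change routing flow (hypothetical stand-ins,
# not the real llama-stack classes).
import asyncio


class StubProvider:
    def __init__(self, alias_to_provider_id: dict):
        self.alias_to_provider_id = alias_to_provider_id

    async def chat_completion(self, model_id: str, messages: list) -> str:
        # Provider-side resolution replaces the router-side lookup removed above.
        provider_model = self.alias_to_provider_id[model_id]
        return f"calling provider with {provider_model!r} ({len(messages)} messages)"


class StubRouter:
    def __init__(self, provider: StubProvider):
        self.provider = provider

    async def chat_completion(self, model_id: str, messages: list) -> str:
        # After this commit the router just forwards model_id untouched.
        return await self.provider.chat_completion(model_id=model_id, messages=messages)


async def main() -> None:
    provider = StubProvider({"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct"})
    print(await StubRouter(provider).chat_completion("Llama3.1-8B-Instruct", ["hi"]))


asyncio.run(main())
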
Changes to the routing tables (register_object_with_provider, CommonRoutingTableImpl, ModelsRoutingTable):

@@ -28,7 +28,9 @@ def get_impl_api(p: Any) -> Api:
     return p.__provider_spec__.api


-async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
+# TODO: this should return the registered object for all APIs
+async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
+
     api = get_impl_api(p)

     if obj.provider_id == "remote":
@@ -42,7 +44,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
         obj.provider_id = ""

     if api == Api.inference:
-        await p.register_model(obj)
+        return await p.register_model(obj)
     elif api == Api.safety:
         await p.register_shield(obj)
     elif api == Api.memory:
@@ -167,7 +169,9 @@ class CommonRoutingTableImpl(RoutingTable):
         assert len(objects) == 1
         return objects[0]

-    async def register_object(self, obj: RoutableObjectWithProvider):
+    async def register_object(
+        self, obj: RoutableObjectWithProvider
+    ) -> RoutableObjectWithProvider:
         # Get existing objects from registry
         existing_objects = await self.dist_registry.get(obj.type, obj.identifier)

@@ -177,7 +181,7 @@ class CommonRoutingTableImpl(RoutingTable):
             print(
                 f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`"
             )
-            return
+            return existing_obj

         # if provider_id is not specified, pick an arbitrary one from existing entries
         if not obj.provider_id and len(self.impls_by_provider_id) > 0:
@@ -188,8 +192,15 @@ class CommonRoutingTableImpl(RoutingTable):

         p = self.impls_by_provider_id[obj.provider_id]

-        await register_object_with_provider(obj, p)
-        await self.dist_registry.register(obj)
+        registered_obj = await register_object_with_provider(obj, p)
+        # TODO: This needs to be fixed for all APIs once they return the registered object
+        if obj.type == ResourceType.model.value:
+            await self.dist_registry.register(registered_obj)
+            return registered_obj
+
+        else:
+            await self.dist_registry.register(obj)
+            return obj

     async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
         objs = await self.dist_registry.get_all()
@@ -228,8 +239,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             provider_id=provider_id,
             metadata=metadata,
         )
-        await self.register_object(model)
-        return model
+        registered_model = await self.register_object(model)
+        return registered_model


 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):

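The registration path now returns what the provider actually registered, so callers can observe provider-side rewrites (for models, a normalized provider_resource_id). A small sketch of that register-and-return pattern, with hypothetical stand-ins for the registry classes:

# Sketch of the register-and-return pattern (hypothetical stand-ins, not the
# real routing-table classes): the provider may rewrite fields on the object,
# so the routing table stores and returns the provider's copy.
import asyncio
from dataclasses import dataclass


@dataclass
class StubModel:
    identifier: str
    provider_resource_id: str


class StubProvider:
    async def register_model(self, model: StubModel) -> StubModel:
        # e.g. normalize an alias to the provider's canonical id
        model.provider_resource_id = model.provider_resource_id.lower()
        return model


class StubRoutingTable:
    def __init__(self, provider: StubProvider):
        self.provider = provider
        self.registry: dict = {}

    async def register_object(self, obj: StubModel) -> StubModel:
        registered = await self.provider.register_model(obj)
        self.registry[registered.identifier] = registered  # keep the returned copy
        return registered


async def main() -> None:
    table = StubRoutingTable(StubProvider())
    model = await table.register_object(StubModel("my-model", "Fireworks/Llama-v3p1-8b-Instruct"))
    print(model.provider_resource_id)  # -> fireworks/llama-v3p1-8b-instruct


asyncio.run(main())
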
Changes to the Bedrock inference adapter (BedrockInferenceAdapter):

@@ -11,7 +11,10 @@ from botocore.client import BaseClient
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer

-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import (
+    ModelAlias,
+    ModelRegistryHelper,
+)

 from llama_stack.apis.inference import *  # noqa: F403

@@ -19,19 +22,26 @@ from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client


-BEDROCK_SUPPORTED_MODELS = {
-    "Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
-    "Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
-    "Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
-}
+model_aliases = [
+    ModelAlias(
+        provider_model_id="meta.llama3-1-8b-instruct-v1:0",
+        aliases=["Llama3.1-8B"],
+    ),
+    ModelAlias(
+        provider_model_id="meta.llama3-1-70b-instruct-v1:0",
+        aliases=["Llama3.1-70B"],
+    ),
+    ModelAlias(
+        provider_model_id="meta.llama3-1-405b-instruct-v1:0",
+        aliases=["Llama3.1-405B"],
+    ),
+]


 # NOTE: this is not quite tested after the recent refactors
 class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: BedrockConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
-        )
+        ModelRegistryHelper.__init__(self, model_aliases)
         self._config = config

         self._client = create_bedrock_client(config)

Changes to the Databricks inference adapter (DatabricksInferenceAdapter):

@@ -37,7 +37,7 @@ DATABRICKS_SUPPORTED_MODELS = {
 class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: DatabricksImplConfig) -> None:
         ModelRegistryHelper.__init__(
-            self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
+            self, provider_to_common_model_aliases_map=DATABRICKS_SUPPORTED_MODELS
         )
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())

Changes to the Fireworks inference adapter (FireworksInferenceAdapter):

@@ -7,14 +7,17 @@
 from typing import AsyncGenerator

 from fireworks.client import Fireworks
+from llama_models.datatypes import CoreModelId

 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import (
+    ModelAlias,
+    ModelRegistryHelper,
+)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
@@ -31,25 +34,61 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import FireworksImplConfig

-FIREWORKS_SUPPORTED_MODELS = {
-    "Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
-    "Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
-    "Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
-    "Llama3.2-1B-Instruct": "fireworks/llama-v3p2-1b-instruct",
-    "Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
-    "Llama3.2-11B-Vision-Instruct": "fireworks/llama-v3p2-11b-vision-instruct",
-    "Llama3.2-90B-Vision-Instruct": "fireworks/llama-v3p2-90b-vision-instruct",
-    "Llama-Guard-3-8B": "fireworks/llama-guard-3-8b",
-}
+model_aliases = [
+    ModelAlias(
+        provider_model_id="fireworks/llama-v3p1-8b-instruct",
+        aliases=["Llama3.1-8B-Instruct"],
+        llama_model=CoreModelId.llama3_1_8b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="fireworks/llama-v3p1-70b-instruct",
+        aliases=["Llama3.1-70B-Instruct"],
+        llama_model=CoreModelId.llama3_1_70b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="fireworks/llama-v3p1-405b-instruct",
+        aliases=["Llama3.1-405B-Instruct"],
+        llama_model=CoreModelId.llama3_1_405b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="fireworks/llama-v3p2-1b-instruct",
+        aliases=["Llama3.2-1B-Instruct"],
+        llama_model=CoreModelId.llama3_2_3b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="fireworks/llama-v3p2-3b-instruct",
+        aliases=["Llama3.2-3B-Instruct"],
+        llama_model=CoreModelId.llama3_2_11b_vision_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="fireworks/llama-v3p2-11b-vision-instruct",
+        aliases=["Llama3.2-11B-Vision-Instruct"],
+        llama_model=CoreModelId.llama3_2_11b_vision_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="fireworks/llama-v3p2-90b-vision-instruct",
+        aliases=["Llama3.2-90B-Vision-Instruct"],
+        llama_model=CoreModelId.llama3_2_90b_vision_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="fireworks/llama-guard-3-8b",
+        aliases=["Llama-Guard-3-8B"],
+        llama_model=CoreModelId.llama_guard_3_8b.value,
+    ),
+    ModelAlias(
+        provider_model_id="fireworks/llama-guard-3-11b-vision",
+        aliases=["Llama-Guard-3-11B-Vision"],
+        llama_model=CoreModelId.llama_guard_3_11b_vision.value,
+    ),
+]


 class FireworksInferenceAdapter(
     ModelRegistryHelper, Inference, NeedsRequestProviderData
 ):
     def __init__(self, config: FireworksImplConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
-        )
+        ModelRegistryHelper.__init__(self, model_aliases)
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())

@@ -81,8 +120,9 @@ class FireworksInferenceAdapter(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -148,8 +188,9 @@ class FireworksInferenceAdapter(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
-            model=model.provider_resource_id,
+            model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -207,7 +248,7 @@ class FireworksInferenceAdapter(
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.formatter
+                    request, self.get_llama_model(request.model), self.formatter
                 )
         else:
             assert (
@@ -221,7 +262,7 @@ class FireworksInferenceAdapter(
             input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]

         return {
-            "model": self.map_to_provider_model(request.model),
+            "model": request.model,
            **input_dict,
            "stream": request.stream,
            **self._build_options(request.sampling_params, request.response_format),

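With the alias table in place, the adapter resolves the registered Model from its model_store, sends the provider-side id to Fireworks, and formats prompts against the underlying llama model. A compact sketch of that resolution step, with illustrative stand-in data instead of the real adapter:

# Stand-in sketch of the adapter's new resolution step: the request carries the
# provider model id, while prompt formatting uses the mapped llama model.
class StubModel:
    def __init__(self, identifier: str, provider_resource_id: str):
        self.identifier = identifier
        self.provider_resource_id = provider_resource_id


# mirrors one entry of the model_aliases table above
PROVIDER_TO_LLAMA = {"fireworks/llama-v3p1-8b-instruct": "Llama3.1-8B-Instruct"}


def build_request_params(model: StubModel) -> dict:
    provider_model_id = model.provider_resource_id        # sent to the Fireworks API
    llama_model = PROVIDER_TO_LLAMA[provider_model_id]    # drives prompt formatting
    return {"model": provider_model_id, "llama_model": llama_model}


registered = StubModel("Llama3.1-8B-Instruct", "fireworks/llama-v3p1-8b-instruct")
print(build_request_params(registered))
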
Changes to the Together inference adapter (TogetherInferenceAdapter):

@@ -6,6 +6,8 @@

 from typing import AsyncGenerator

+from llama_models.datatypes import CoreModelId
+
 from llama_models.llama3.api.chat_format import ChatFormat

 from llama_models.llama3.api.datatypes import Message
@@ -15,7 +17,10 @@ from together import Together

 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import (
+    ModelAlias,
+    ModelRegistryHelper,
+)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
@@ -33,25 +38,55 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import TogetherImplConfig


-TOGETHER_SUPPORTED_MODELS = {
-    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
-    "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
-    "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
-    "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
-    "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
-    "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
-    "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
-    "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
-}
+model_aliases = [
+    ModelAlias(
+        provider_model_id="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
+        aliases=["Llama3.1-8B-Instruct"],
+        llama_model=CoreModelId.llama3_1_8b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+        aliases=["Llama3.1-70B-Instruct"],
+        llama_model=CoreModelId.llama3_1_70b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
+        aliases=["Llama3.1-405B-Instruct"],
+        llama_model=CoreModelId.llama3_1_405b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="meta-llama/Llama-3.2-3B-Instruct-Turbo",
+        aliases=["Llama3.2-3B-Instruct"],
+        llama_model=CoreModelId.llama3_2_3b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
+        aliases=["Llama3.2-11B-Vision-Instruct"],
+        llama_model=CoreModelId.llama3_2_11b_vision_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
+        aliases=["Llama3.2-90B-Vision-Instruct"],
+        llama_model=CoreModelId.llama3_2_90b_vision_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="meta-llama/Meta-Llama-Guard-3-8B",
+        aliases=["Llama-Guard-3-8B"],
+        llama_model=CoreModelId.llama_guard_3_8b.value,
+    ),
+    ModelAlias(
+        provider_model_id="meta-llama/Llama-Guard-3-11B-Vision-Turbo",
+        aliases=["Llama-Guard-3-11B-Vision"],
+        llama_model=CoreModelId.llama_guard_3_11b_vision.value,
+    ),
+]


 class TogetherInferenceAdapter(
     ModelRegistryHelper, Inference, NeedsRequestProviderData
 ):
     def __init__(self, config: TogetherImplConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
-        )
+        ModelRegistryHelper.__init__(self, model_aliases)
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())

@@ -70,8 +105,9 @@ class TogetherInferenceAdapter(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -145,8 +181,9 @@ class TogetherInferenceAdapter(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -204,7 +241,7 @@ class TogetherInferenceAdapter(
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.formatter
+                    request, self.get_llama_model(request.model), self.formatter
                 )
         else:
            assert (
@@ -213,7 +250,7 @@ class TogetherInferenceAdapter(
             input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)

         return {
-            "model": self.map_to_provider_model(request.model),
+            "model": request.model,
            **input_dict,
            "stream": request.stream,
            **self._build_options(request.sampling_params, request.response_format),

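Together follows the same pattern; the practical consequence is that a stack-level alias such as "Llama3.1-8B-Instruct" now resolves to Together's Turbo model id at registration time. A tiny illustration of that alias-to-provider-id mapping, hand-rolled from two of the entries above (the dict is a stand-in for the lookup built in the model registry file below):

# Alias resolution illustrated with two Together entries from the table above.
ALIAS_TO_PROVIDER_ID = {
    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
    "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
}


def resolve(identifier: str) -> str:
    # a provider id maps to itself, so both forms are accepted
    if identifier in ALIAS_TO_PROVIDER_ID.values():
        return identifier
    return ALIAS_TO_PROVIDER_ID[identifier]


print(resolve("Llama3.1-8B-Instruct"))              # -> meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo
print(resolve("meta-llama/Meta-Llama-Guard-3-8B"))  # unchanged
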
Changes to the model registry helper (ModelRegistryHelper):

@@ -4,32 +4,54 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Dict
-
-from llama_models.sku_list import resolve_model
+from collections import namedtuple
+from typing import List

 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate

+ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])
+
+
+class ModelLookup:
+    def __init__(
+        self,
+        model_aliases: List[ModelAlias],
+    ):
+        self.alias_to_provider_id_map = {}
+        self.provider_id_to_llama_model_map = {}
+        for alias_obj in model_aliases:
+            for alias in alias_obj.aliases:
+                self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
+            # also add a mapping from provider model id to itself for easy lookup
+            self.alias_to_provider_id_map[alias_obj.provider_model_id] = (
+                alias_obj.provider_model_id
+            )
+            self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = (
+                alias_obj.llama_model
+            )
+
+    def get_provider_model_id(self, identifier: str) -> str:
+        if identifier in self.alias_to_provider_id_map:
+            return self.alias_to_provider_id_map[identifier]
+        else:
+            raise ValueError(f"Unknown model: `{identifier}`")
+

 class ModelRegistryHelper(ModelsProtocolPrivate):

-    def __init__(self, stack_to_provider_models_map: Dict[str, str]):
-        self.stack_to_provider_models_map = stack_to_provider_models_map
+    def __init__(self, model_aliases: List[ModelAlias]):
+        self.model_lookup = ModelLookup(model_aliases)

-    def map_to_provider_model(self, identifier: str) -> str:
-        model = resolve_model(identifier)
-        if not model:
-            raise ValueError(f"Unknown model: `{identifier}`")
+    def get_llama_model(self, provider_model_id: str) -> str:
+        return self.model_lookup.provider_id_to_llama_model_map[provider_model_id]

-        if identifier not in self.stack_to_provider_models_map:
-            raise ValueError(
-                f"Model {identifier} not found in map {self.stack_to_provider_models_map}"
-            )
+    async def register_model(self, model: Model) -> Model:
+        provider_model_id = self.model_lookup.get_provider_model_id(
+            model.provider_resource_id
+        )
+        if not provider_model_id:
+            raise ValueError(f"Unknown model: `{model.provider_resource_id}`")

-        return self.stack_to_provider_models_map[identifier]
+        model.provider_resource_id = provider_model_id

-    async def register_model(self, model: Model) -> None:
-        if model.provider_resource_id not in self.stack_to_provider_models_map:
-            raise ValueError(
-                f"Unsupported model {model.provider_resource_id}. Supported models: {self.stack_to_provider_models_map.keys()}"
-            )
+        return model

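This is the core of the change: ModelAlias entries feed a ModelLookup that maps any accepted alias (including the provider id itself) to the provider model id, and register_model rewrites the model's provider_resource_id accordingly. A short, self-contained illustration of that behaviour; it re-declares the namedtuple rather than importing llama-stack, and spells llama_model as a plain string for brevity:

# Self-contained illustration of the new registry behaviour; ModelAlias matches
# the namedtuple defined above, llama_model is given as a plain string here.
from collections import namedtuple

ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])

aliases = [
    ModelAlias(
        provider_model_id="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
        aliases=["Llama3.1-8B-Instruct"],
        llama_model="Llama3.1-8B-Instruct",
    ),
]

alias_to_provider_id = {}
provider_id_to_llama_model = {}
for a in aliases:
    for name in a.aliases:
        alias_to_provider_id[name] = a.provider_model_id
    alias_to_provider_id[a.provider_model_id] = a.provider_model_id  # identity mapping
    provider_id_to_llama_model[a.provider_model_id] = a.llama_model

# what register_model now does: normalize provider_resource_id via the lookup
provider_resource_id = alias_to_provider_id["Llama3.1-8B-Instruct"]
print(provider_resource_id)                              # -> meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo
print(provider_id_to_llama_model[provider_resource_id])  # -> Llama3.1-8B-Instruct
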
Changes to the prompt adapter helpers (chat_completion_request_to_prompt and friends):

@@ -147,17 +147,17 @@ def augment_content_with_response_format_prompt(response_format, content):


 def chat_completion_request_to_prompt(
-    request: ChatCompletionRequest, formatter: ChatFormat
+    request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
 ) -> str:
-    messages = chat_completion_request_to_messages(request)
+    messages = chat_completion_request_to_messages(request, llama_model)
     model_input = formatter.encode_dialog_prompt(messages)
     return formatter.tokenizer.decode(model_input.tokens)


 def chat_completion_request_to_model_input_info(
-    request: ChatCompletionRequest, formatter: ChatFormat
+    request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
 ) -> Tuple[str, int]:
-    messages = chat_completion_request_to_messages(request)
+    messages = chat_completion_request_to_messages(request, llama_model)
     model_input = formatter.encode_dialog_prompt(messages)
     return (
         formatter.tokenizer.decode(model_input.tokens),
@@ -167,14 +167,15 @@ def chat_completion_request_to_model_input_info(

 def chat_completion_request_to_messages(
     request: ChatCompletionRequest,
+    llama_model: str,
 ) -> List[Message]:
     """Reads chat completion request and augments the messages to handle tools.
     For eg. for llama_3_1, add system message with the appropriate tools or
     add user messsage for custom tools, etc.
     """
-    model = resolve_model(request.model)
+    model = resolve_model(llama_model)
     if model is None:
-        cprint(f"Could not resolve model {request.model}", color="red")
+        cprint(f"Could not resolve model {llama_model}", color="red")
         return request.messages

     if model.descriptor() not in supported_inference_models():

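Because request.model now carries the provider's model id, the prompt helpers must be told explicitly which llama model to format for; the adapters pass self.get_llama_model(request.model). A minimal sketch of that parameter threading, with stand-in functions that only show the data flow, not the real llama-stack helpers:

# Stand-in sketch of the new parameter threading; these are not the real
# llama-stack helpers, they only show how llama_model flows through the calls.
from typing import List


def chat_request_to_messages(messages: List[str], llama_model: str) -> List[str]:
    # the real helper augments messages based on the resolved llama model
    return [f"[system prompt formatted for {llama_model}]"] + messages


def chat_request_to_prompt(messages: List[str], llama_model: str) -> str:
    return "\n".join(chat_request_to_messages(messages, llama_model))


# request.model would be e.g. "fireworks/llama-v3p1-8b-instruct"; the adapter
# passes the mapped llama model separately:
print(chat_request_to_prompt(["hello"], llama_model="Llama3.1-8B-Instruct"))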