Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-10-04 04:04:14 +00:00

Merge f0211ffb70 into sapling-pr-archive-ehhuang

commit 953f51f87a
11 changed files with 808 additions and 1132 deletions
.github/workflows/python-build-test.yml (vendored, 2 changes)

@@ -24,7 +24,7 @@ jobs:
         uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Install uv
-        uses: astral-sh/setup-uv@557e51de59eb14aaaba2ed9621916900a91d50c6 # v6.6.1
+        uses: astral-sh/setup-uv@b75a909f75acd358c2196fb9a5f1299a9a8868a4 # v6.7.0
         with:
           python-version: ${{ matrix.python-version }}
           activate-environment: true

@@ -164,7 +164,7 @@ RUN apt-get update && apt-get install -y \
     procps psmisc lsof \
    traceroute \
    bubblewrap \
-    gcc \
+    gcc g++ \
    && rm -rf /var/lib/apt/lists/*

 ENV UV_SYSTEM_PYTHON=1

@@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             try:
                 models = await provider.list_models()
             except Exception as e:
-                logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
+                logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
                 continue

             self.listed_providers.add(provider_id)
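
The only change here is the log level: a failed per-provider model refresh is now logged as a warning instead of a full exception trace before the loop moves on. A minimal sketch of the surrounding refresh loop; names outside the hunk (the function name and its parameters) are assumptions, not taken from this diff:

import logging

logger = logging.getLogger("inference_routing_sketch")


async def refresh_models(providers: dict, listed_providers: set) -> None:
    """Sketch of the loop around the changed line; signatures are assumed."""
    for provider_id, provider in providers.items():
        try:
            models = await provider.list_models()
        except Exception as e:
            # Downgraded from logger.exception: one unreachable provider should not
            # dump a full stack trace on every refresh pass.
            logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
            continue

        listed_providers.add(provider_id)
        # ... register `models` with the routing table here ...
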
@@ -4,11 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from collections.abc import AsyncGenerator, AsyncIterator
-from typing import Any
+from collections.abc import AsyncGenerator

 from fireworks.client import Fireworks
-from openai import AsyncOpenAI

 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -24,12 +22,6 @@ from llama_stack.apis.inference import (
     Inference,
     LogProbConfig,
     Message,
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-    OpenAIEmbeddingsResponse,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
@@ -45,15 +37,14 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
     convert_message_to_openai_dict,
     get_sampling_options,
-    prepare_openai_completion_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
@@ -68,7 +59,7 @@ from .models import MODEL_ENTRIES
 logger = get_logger(name=__name__, category="inference::fireworks")


-class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
+class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: FireworksImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
         self.config = config
@@ -79,7 +70,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     async def shutdown(self) -> None:
         pass

-    def _get_api_key(self) -> str:
+    def get_api_key(self) -> str:
         config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
         if config_api_key:
             return config_api_key
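
The rename to get_api_key() keeps the same fallback order: a key from the provider config wins, otherwise the key comes from per-request provider data (the elided lines return provider_data.fireworks_api_key). As a hedged illustration of the per-request path, assuming the usual llama-stack convention of a JSON-encoded X-LlamaStack-Provider-Data header and a locally running server on port 8321:

# Sketch only: the header name, port, and endpoint path are assumptions, not part of this diff.
import json

import requests

headers = {"X-LlamaStack-Provider-Data": json.dumps({"fireworks_api_key": "fw-..."})}
resp = requests.post(
    "http://localhost:8321/v1/openai/v1/chat/completions",
    headers=headers,
    json={
        "model": "accounts/fireworks/models/llama-v3p1-8b-instruct",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(resp.status_code, resp.json())
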
@@ -91,15 +82,18 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         )
         return provider_data.fireworks_api_key

-    def _get_base_url(self) -> str:
+    def get_base_url(self) -> str:
         return "https://api.fireworks.ai/inference/v1"

     def _get_client(self) -> Fireworks:
-        fireworks_api_key = self._get_api_key()
+        fireworks_api_key = self.get_api_key()
         return Fireworks(api_key=fireworks_api_key)

-    def _get_openai_client(self) -> AsyncOpenAI:
-        return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())
+    def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
+        """Remove BOS token as Fireworks automatically prepends it"""
+        if prompt.startswith("<|begin_of_text|>"):
+            return prompt[len("<|begin_of_text|>") :]
+        return prompt

     async def completion(
         self,
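
The new helper strips a leading BOS marker because Fireworks prepends one server-side, so sending it again would duplicate the token. A quick, standalone illustration of the intended behavior (not a test from this PR):

# Behavior sketch for the BOS-stripping helper added above.
def _preprocess_prompt_for_fireworks(prompt: str) -> str:
    """Remove BOS token as Fireworks automatically prepends it."""
    if prompt.startswith("<|begin_of_text|>"):
        return prompt[len("<|begin_of_text|>") :]
    return prompt


assert _preprocess_prompt_for_fireworks("<|begin_of_text|>Hello") == "Hello"
assert _preprocess_prompt_for_fireworks("Hello") == "Hello"  # prompts without BOS pass through
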
@@ -285,153 +279,3 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):

         embeddings = [data.embedding for data in response.data]
         return EmbeddingsResponse(embeddings=embeddings)
-
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
-
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        model_obj = await self.model_store.get_model(model)
-
-        # Fireworks always prepends with BOS
-        if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
-            prompt = prompt[len("<|begin_of_text|>") :]
-
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-        )
-
-        return await self._get_openai_client().completions.create(**params)
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        model_obj = await self.model_store.get_model(model)
-
-        # Divert Llama Models through Llama Stack inference APIs because
-        # Fireworks chat completions OpenAI-compatible API does not support
-        # tool calls properly.
-        llama_model = self.get_llama_model(model_obj.provider_resource_id)
-
-        if llama_model:
-            return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
-                self,
-                model=model,
-                messages=messages,
-                frequency_penalty=frequency_penalty,
-                function_call=function_call,
-                functions=functions,
-                logit_bias=logit_bias,
-                logprobs=logprobs,
-                max_completion_tokens=max_completion_tokens,
-                max_tokens=max_tokens,
-                n=n,
-                parallel_tool_calls=parallel_tool_calls,
-                presence_penalty=presence_penalty,
-                response_format=response_format,
-                seed=seed,
-                stop=stop,
-                stream=stream,
-                stream_options=stream_options,
-                temperature=temperature,
-                tool_choice=tool_choice,
-                tools=tools,
-                top_logprobs=top_logprobs,
-                top_p=top_p,
-                user=user,
-            )
-
-        params = await prepare_openai_completion_params(
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-
-        logger.debug(f"fireworks params: {params}")
-        return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)

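With the hand-written openai_embeddings, openai_completion, and openai_chat_completion methods removed, the OpenAI-compatible surface is presumably inherited from the newly added OpenAIMixin, which only needs the get_api_key() and get_base_url() hooks renamed above. A rough sketch of what such a mixin boils down to; this is an assumption for illustration, not the actual llama_stack implementation:

# Illustrative-only mixin sketch; the real OpenAIMixin in llama_stack is not shown in this diff.
from openai import AsyncOpenAI


class OpenAIMixinSketch:
    def get_api_key(self) -> str:  # hook the adapter must provide
        raise NotImplementedError

    def get_base_url(self) -> str:  # hook the adapter must provide
        raise NotImplementedError

    @property
    def client(self) -> AsyncOpenAI:
        # One shared place to build the OpenAI-compatible client, replacing the
        # per-adapter _get_openai_client() removed above.
        return AsyncOpenAI(base_url=self.get_base_url(), api_key=self.get_api_key())
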
@@ -4,15 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from pydantic import BaseModel
-
 from .config import GeminiConfig


-class GeminiProviderDataValidator(BaseModel):
-    gemini_api_key: str | None = None
-
-
 async def get_adapter_impl(config: GeminiConfig, _deps):
     from .gemini import GeminiInferenceAdapter

@@ -4,15 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from pydantic import BaseModel
-
 from .config import OpenAIConfig


-class OpenAIProviderDataValidator(BaseModel):
-    openai_api_key: str | None = None
-
-
 async def get_adapter_impl(config: OpenAIConfig, _deps):
     from .openai import OpenAIInferenceAdapter

@@ -103,7 +103,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             Model(
                 identifier=id,
                 provider_resource_id=entry.provider_model_id,
-                model_type=ModelType.llm,
+                model_type=entry.model_type,
                 metadata=entry.metadata,
                 provider_id=self.__provider_id__,
             )
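
Previously every registry entry was registered as ModelType.llm; now the type declared on the entry is preserved, which matters for providers whose model entries include embedding models. A hedged sketch of the kind of entry that benefits, using a stand-in entry type rather than the real class (field names beyond those in the hunk are assumptions):

# Sketch only; a stand-in entry type is defined here instead of importing the real one.
from dataclasses import dataclass, field
from enum import Enum


class ModelType(str, Enum):  # stand-in for the ModelType referenced in the hunk
    llm = "llm"
    embedding = "embedding"


@dataclass
class ProviderModelEntrySketch:
    provider_model_id: str
    model_type: ModelType = ModelType.llm
    metadata: dict = field(default_factory=dict)


# With the change above, this entry registers as an embedding model rather than an LLM.
embedding_entry = ProviderModelEntrySketch(
    provider_model_id="nomic-ai/nomic-embed-text-v1.5",
    model_type=ModelType.embedding,
    metadata={"embedding_dimension": 768},  # illustrative metadata
)
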
llama_stack/ui/package-lock.json (generated, 1724 changes)
File diff suppressed because it is too large

@@ -14,7 +14,7 @@
   },
   "dependencies": {
     "@radix-ui/react-collapsible": "^1.1.12",
-    "@radix-ui/react-dialog": "^1.1.13",
+    "@radix-ui/react-dialog": "^1.1.15",
     "@radix-ui/react-dropdown-menu": "^2.1.16",
     "@radix-ui/react-select": "^2.2.6",
     "@radix-ui/react-separator": "^1.1.7",
@@ -32,7 +32,7 @@
     "react-dom": "^19.1.1",
     "react-markdown": "^10.1.0",
     "remark-gfm": "^4.0.1",
-    "remeda": "^2.30.0",
+    "remeda": "^2.32.0",
     "shiki": "^1.29.2",
     "sonner": "^2.0.7",
     "tailwind-merge": "^3.3.1"
@@ -52,7 +52,7 @@
     "eslint-config-prettier": "^10.1.8",
     "eslint-plugin-prettier": "^5.5.4",
     "jest": "^29.7.0",
-    "jest-environment-jsdom": "^29.7.0",
+    "jest-environment-jsdom": "^30.1.2",
     "prettier": "3.6.2",
     "tailwindcss": "^4",
     "ts-node": "^10.9.2",

@@ -33,6 +33,7 @@ def skip_if_model_doesnt_support_user_param(client, model_id):
     provider = provider_from_model(client, model_id)
     if provider.provider_type in (
         "remote::together",  # service returns 400
+        "remote::fireworks",  # service returns 400 malformed input
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support user param.")

@@ -41,6 +42,7 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):
     provider = provider_from_model(client, model_id)
     if provider.provider_type in (
         "remote::together",  # param silently ignored, always returns floats
+        "remote::fireworks",  # param silently ignored, always returns list of floats
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support encoding_format='base64'.")

@@ -287,7 +289,6 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_models
         input=input_texts,
         encoding_format="base64",
     )
-
     # Validate response structure
     assert response.object == "list"
     assert response.model == embedding_model_id
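
The new Fireworks skips exist because the provider silently ignores encoding_format='base64'. For context, a hedged sketch of what test_openai_embeddings_base64_batch_processing checks on providers that do honor the parameter: each returned embedding is a base64 string that decodes to float32 values. Client construction, URL, and model name here are assumptions, not taken from the diff:

# Sketch of decoding base64-encoded embeddings from an OpenAI-compatible endpoint.
import base64
import struct

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")  # assumed llama-stack URL
input_texts = ["Hello, world!", "Test sentence for batch embedding"]

response = client.embeddings.create(
    model="nomic-ai/nomic-embed-text-v1.5",  # embedding model used in the fireworks setup below
    input=input_texts,
    encoding_format="base64",
)

assert response.object == "list"
for item in response.data:
    raw = base64.b64decode(item.embedding)              # base64 string instead of a float list
    floats = struct.unpack(f"<{len(raw) // 4}f", raw)   # little-endian float32 values
    print(len(floats))
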
@@ -108,6 +108,15 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
             "embedding_model": "together/togethercomputer/m2-bert-80M-32k-retrieval",
         },
     ),
+    "fireworks": Setup(
+        name="fireworks",
+        description="Fireworks provider with a text model",
+        defaults={
+            "text_model": "accounts/fireworks/models/llama-v3p1-8b-instruct",
+            "vision_model": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
+            "embedding_model": "nomic-ai/nomic-embed-text-v1.5",
+        },
+    ),
 }
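
For completeness, a minimal sketch of reading the new entry back out of SETUP_DEFINITIONS; how the integration-test runner actually selects a setup is not shown in this diff:

# Sketch only: assumes SETUP_DEFINITIONS and the Setup class are importable from the suite's setup module.
fireworks_setup = SETUP_DEFINITIONS["fireworks"]
print(fireworks_setup.name)                          # fireworks
print(fireworks_setup.defaults["text_model"])        # accounts/fireworks/models/llama-v3p1-8b-instruct
print(fireworks_setup.defaults["embedding_model"])   # nomic-ai/nomic-embed-text-v1.5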