Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-05 20:27:35 +00:00)
Commit 14f7b0d843: Merge branch 'main' into prompt-api
8 changed files with 92 additions and 13 deletions
@@ -3,6 +3,7 @@ image_name: kubernetes-benchmark-demo
 apis:
 - agents
 - inference
+- safety
 - telemetry
 - tool_runtime
 - vector_io
@@ -30,6 +31,11 @@ providers:
         db: ${env.POSTGRES_DB:=llamastack}
         user: ${env.POSTGRES_USER:=llamastack}
         password: ${env.POSTGRES_PASSWORD:=llamastack}
+  safety:
+  - provider_id: llama-guard
+    provider_type: inline::llama-guard
+    config:
+      excluded_categories: []
   agents:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -95,6 +101,8 @@ models:
 - model_id: ${env.INFERENCE_MODEL}
   provider_id: vllm-inference
   model_type: llm
+shields:
+- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
 vector_dbs: []
 datasets: []
 scoring_fns: []

@@ -527,7 +527,7 @@ class InferenceRouter(Inference):

         # Store the response with the ID that will be returned to the client
         if self.store:
-            await self.store.store_chat_completion(response, messages)
+            asyncio.create_task(self.store.store_chat_completion(response, messages))

         if self.telemetry:
             metrics = self._construct_metrics(
@@ -855,4 +855,4 @@ class InferenceRouter(Inference):
             object="chat.completion",
         )
         logger.debug(f"InferenceRouter.completion_response: {final_response}")
-        await self.store.store_chat_completion(final_response, messages)
+        asyncio.create_task(self.store.store_chat_completion(final_response, messages))

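Note: both InferenceRouter hunks change the store write from an awaited call into a fire-and-forget background task, so the completion is returned to the client without blocking on persistence. Below is a minimal sketch of that pattern with a stand-in coroutine rather than the real store; one caveat of a bare asyncio.create_task is that the caller should hold a reference to the task (or attach a done-callback) so failures are not silently dropped.

import asyncio


async def store_chat_completion(response_id: str, messages: list[str]) -> None:
    # Stand-in for the real store call; pretend this does slow I/O.
    await asyncio.sleep(0.1)
    print(f"stored {response_id!r} ({len(messages)} messages)")


async def handle_request() -> str:
    response_id = "chatcmpl-123"
    # Fire-and-forget: schedule the write and return without awaiting it.
    task = asyncio.create_task(store_chat_completion(response_id, ["hello"]))
    # Surface failures instead of letting them disappear with the task.
    task.add_done_callback(lambda t: t.exception() and print(f"store failed: {t.exception()}"))
    return response_id


async def main() -> None:
    print(await handle_request())  # returns immediately
    await asyncio.sleep(0.2)       # keep the loop alive long enough for the write to land


asyncio.run(main())
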
@@ -141,6 +141,8 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
         return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc))
     elif isinstance(exc, PermissionError | AccessDeniedError):
         return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}")
+    elif isinstance(exc, ConnectionError | httpx.ConnectError):
+        return HTTPException(status_code=httpx.codes.BAD_GATEWAY, detail=str(exc))
     elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
         return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}")
     elif isinstance(exc, NotImplementedError):

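Note: the new branch maps connection failures to HTTP 502 (httpx.codes.BAD_GATEWAY is the integer 502), and the isinstance check uses a PEP 604 union (ConnectionError | httpx.ConnectError), which isinstance accepts on Python 3.10+. A rough, self-contained sketch of just that branch; translate is an illustrative name here, and the surrounding branches are omitted.

import httpx
from fastapi import HTTPException


def translate(exc: Exception) -> HTTPException:
    # isinstance() accepts X | Y unions on Python 3.10+, equivalent to a tuple of types.
    if isinstance(exc, ConnectionError | httpx.ConnectError):
        return HTTPException(status_code=httpx.codes.BAD_GATEWAY, detail=str(exc))
    return HTTPException(status_code=httpx.codes.INTERNAL_SERVER_ERROR, detail=str(exc))


assert translate(ConnectionError("connection refused")).status_code == 502
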
@@ -116,7 +116,7 @@ def available_providers() -> list[ProviderSpec]:
         adapter=AdapterSpec(
             adapter_type="fireworks",
             pip_packages=[
-                "fireworks-ai<=0.18.0",
+                "fireworks-ai<=0.17.16",
             ],
             module="llama_stack.providers.remote.inference.fireworks",
             config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",

@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import asyncio
 import base64
 import struct
 from typing import TYPE_CHECKING
@@ -43,9 +44,11 @@ class SentenceTransformerEmbeddingMixin:
         task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
-        embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
-        embeddings = embedding_model.encode(
-            [interleaved_content_as_str(content) for content in contents], show_progress_bar=False
+        embedding_model = await self._load_sentence_transformer_model(model.provider_resource_id)
+        embeddings = await asyncio.to_thread(
+            embedding_model.encode,
+            [interleaved_content_as_str(content) for content in contents],
+            show_progress_bar=False,
         )
         return EmbeddingsResponse(embeddings=embeddings)

@@ -64,8 +67,8 @@ class SentenceTransformerEmbeddingMixin:

         # Get the model and generate embeddings
         model_obj = await self.model_store.get_model(model)
-        embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id)
-        embeddings = embedding_model.encode(input_list, show_progress_bar=False)
+        embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
+        embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)

         # Convert embeddings to the requested format
         data = []
@@ -93,7 +96,7 @@ class SentenceTransformerEmbeddingMixin:
             usage=usage,
         )

-    def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
+    async def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
         global EMBEDDING_MODELS

         loaded_model = EMBEDDING_MODELS.get(model)
@@ -101,8 +104,12 @@ class SentenceTransformerEmbeddingMixin:
             return loaded_model

         log.info(f"Loading sentence transformer for {model}...")
-        from sentence_transformers import SentenceTransformer

-        loaded_model = SentenceTransformer(model)
+        def _load_model():
+            from sentence_transformers import SentenceTransformer
+
+            return SentenceTransformer(model)
+
+        loaded_model = await asyncio.to_thread(_load_model)
         EMBEDDING_MODELS[model] = loaded_model
         return loaded_model

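Note: taken together, these hunks move the blocking SentenceTransformer work (both the model load and encode) off the event loop via asyncio.to_thread, so concurrent requests are no longer stalled behind CPU-bound embedding calls. A minimal, self-contained sketch of the pattern with a stand-in encoder; SentenceTransformer itself is not used here.

import asyncio
import time


def slow_encode(texts: list[str]) -> list[list[float]]:
    # Stand-in for embedding_model.encode(): blocking, CPU-bound work.
    time.sleep(0.2)
    return [[float(len(t))] for t in texts]


async def embed(texts: list[str]) -> list[list[float]]:
    # asyncio.to_thread runs the blocking call in a worker thread and awaits the result,
    # so other coroutines keep running while the encode is in flight.
    return await asyncio.to_thread(slow_encode, texts)


async def main() -> None:
    # The two embed calls overlap instead of serializing on the event loop.
    a, b = await asyncio.gather(embed(["hello"]), embed(["world", "again"]))
    print(a, b)


asyncio.run(main())
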
@@ -67,6 +67,38 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator
                     raise AuthenticationRequiredError(exc) from exc
             if i == len(connection_strategies) - 1:
                 raise
+        except* httpx.ConnectError as eg:
+            # Connection refused, server down, network unreachable
+            if i == len(connection_strategies) - 1:
+                error_msg = f"Failed to connect to MCP server at {endpoint}: Connection refused"
+                logger.error(f"MCP connection error: {error_msg}")
+                raise ConnectionError(error_msg) from eg
+            else:
+                logger.warning(
+                    f"failed to connect to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
+                )
+        except* httpx.TimeoutException as eg:
+            # Request timeout, server too slow
+            if i == len(connection_strategies) - 1:
+                error_msg = f"MCP server at {endpoint} timed out"
+                logger.error(f"MCP timeout error: {error_msg}")
+                raise TimeoutError(error_msg) from eg
+            else:
+                logger.warning(
+                    f"MCP server at {endpoint} timed out via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
+                )
+        except* httpx.RequestError as eg:
+            # DNS resolution failures, network errors, invalid URLs
+            if i == len(connection_strategies) - 1:
+                # Get the first exception's message for the error string
+                exc_msg = str(eg.exceptions[0]) if eg.exceptions else "Unknown error"
+                error_msg = f"Network error connecting to MCP server at {endpoint}: {exc_msg}"
+                logger.error(f"MCP network error: {error_msg}")
+                raise ConnectionError(error_msg) from eg
+            else:
+                logger.warning(
+                    f"network error connecting to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
+                )
         except* McpError:
             if i < len(connection_strategies) - 1:
                 logger.warning(

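Note: the except* clauses above use Python 3.11+ exception-group syntax; the MCP client stack raises ExceptionGroup instances (for example from anyio task groups), and each except* clause receives the sub-group of matching exceptions. A small illustration of that matching behaviour, independent of MCP and httpx:

# except* (Python 3.11+) splits an ExceptionGroup by type; each clause sees a
# sub-group containing only the matching exceptions.
try:
    raise ExceptionGroup(
        "connection attempt failed",
        [ConnectionError("refused"), TimeoutError("too slow")],
    )
except* ConnectionError as eg:
    print("connect errors:", [str(e) for e in eg.exceptions])
except* TimeoutError as eg:
    print("timeouts:", [str(e) for e in eg.exceptions])
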
@@ -5,6 +5,8 @@
 # the root directory of this source tree.


+import time
+
 import pytest

 from ..test_cases.test_case import TestCase
@@ -323,8 +325,15 @@ def test_inference_store(compat_client, client_with_models, text_model_id, stream):
     response_id = response.id
     content = response.choices[0].message.content

-    responses = client.chat.completions.list(limit=1000)
-    assert response_id in [r.id for r in responses.data]
+    tries = 0
+    while tries < 10:
+        responses = client.chat.completions.list(limit=1000)
+        if response_id in [r.id for r in responses.data]:
+            break
+        else:
+            tries += 1
+            time.sleep(0.1)
+    assert tries < 10, f"Response {response_id} not found after 1 second"

     retrieved_response = client.chat.completions.retrieve(response_id)
     assert retrieved_response.id == response_id
@@ -388,6 +397,18 @@ def test_inference_store_tool_calls(compat_client, client_with_models, text_model_id, stream):
     response_id = response.id
     content = response.choices[0].message.content

+    # wait for the response to be stored
+    tries = 0
+    while tries < 10:
+        responses = client.chat.completions.list(limit=1000)
+        if response_id in [r.id for r in responses.data]:
+            break
+        else:
+            tries += 1
+            time.sleep(0.1)
+
+    assert tries < 10, f"Response {response_id} not found after 1 second"
+
     responses = client.chat.completions.list(limit=1000)
     assert response_id in [r.id for r in responses.data]

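Note: because the router now persists completions on a background task, both tests poll for the stored response instead of asserting immediately. The loop is duplicated across the two tests; if it spreads further, it could be pulled into a small helper along these lines (wait_for_stored_completion is a hypothetical name, not part of this change):

import time


def wait_for_stored_completion(client, response_id: str, attempts: int = 10, delay: float = 0.1) -> None:
    # Poll the completions list until the stored response shows up or we give up.
    for _ in range(attempts):
        responses = client.chat.completions.list(limit=1000)
        if response_id in [r.id for r in responses.data]:
            return
        time.sleep(delay)
    raise AssertionError(f"Response {response_id} not found after {attempts * delay:.1f} seconds")
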
@@ -113,6 +113,15 @@ class TestTranslateException:
         assert result.status_code == 504
         assert result.detail == "Operation timed out: "

+    def test_translate_connection_error(self):
+        """Test that ConnectionError is translated to 502 HTTP status."""
+        exc = ConnectionError("Failed to connect to MCP server at http://localhost:9999/sse: Connection refused")
+        result = translate_exception(exc)
+
+        assert isinstance(result, HTTPException)
+        assert result.status_code == 502
+        assert result.detail == "Failed to connect to MCP server at http://localhost:9999/sse: Connection refused"
+
     def test_translate_not_implemented_error(self):
         """Test that NotImplementedError is translated to 501 HTTP status."""
         exc = NotImplementedError("Not implemented")