mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-02 22:02:16 +00:00
todos
This commit is contained in:
parent
011fd59a29
commit
8a576d7d72
24 changed files with 297 additions and 2525 deletions
|
|
@ -8,19 +8,12 @@ import time
|
|||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
URL,
|
||||
)
|
||||
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
||||
from llama_stack.apis.eval import (
|
||||
BenchmarkConfig,
|
||||
Eval,
|
||||
EvaluateResponse,
|
||||
Job,
|
||||
JobStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEventType,
|
||||
|
|
@ -42,12 +35,6 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.scoring import (
|
||||
ScoreBatchResponse,
|
||||
ScoreResponse,
|
||||
Scoring,
|
||||
ScoringFnParams,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
|
||||
from llama_stack.apis.tools import (
|
||||
|
|
@ -94,9 +81,7 @@ class VectorIORouter(VectorIO):
|
|||
provider_id: Optional[str] = None,
|
||||
provider_vector_db_id: Optional[str] = None,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
|
||||
)
|
||||
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
|
||||
await self.routing_table.register_vector_db(
|
||||
vector_db_id,
|
||||
embedding_model,
|
||||
|
|
@ -114,9 +99,7 @@ class VectorIORouter(VectorIO):
|
|||
logger.debug(
|
||||
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
|
||||
vector_db_id, chunks, ttl_seconds
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||
|
||||
async def query_chunks(
|
||||
self,
|
||||
|
|
@ -125,9 +108,7 @@ class VectorIORouter(VectorIO):
|
|||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryChunksResponse:
|
||||
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
|
||||
vector_db_id, query, params
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
|
||||
|
||||
|
||||
class InferenceRouter(Inference):
|
||||
|
|
@ -164,9 +145,7 @@ class InferenceRouter(Inference):
|
|||
logger.debug(
|
||||
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
|
||||
)
|
||||
await self.routing_table.register_model(
|
||||
model_id, provider_model_id, provider_id, metadata, model_type
|
||||
)
|
||||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||
|
||||
def _construct_metrics(
|
||||
self,
|
||||
|
|
@ -220,16 +199,11 @@ class InferenceRouter(Inference):
|
|||
total_tokens: int,
|
||||
model: Model,
|
||||
) -> List[MetricInResponse]:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens, completion_tokens, total_tokens, model
|
||||
)
|
||||
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
|
||||
if self.telemetry:
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
return [
|
||||
MetricInResponse(metric=metric.metric, value=metric.value)
|
||||
for metric in metrics
|
||||
]
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
|
||||
|
||||
async def _count_tokens(
|
||||
self,
|
||||
|
|
@ -254,9 +228,7 @@ class InferenceRouter(Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[
|
||||
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
|
||||
]:
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
logger.debug(
|
||||
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||
)
|
||||
|
|
@ -266,19 +238,12 @@ class InferenceRouter(Inference):
|
|||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.embedding:
|
||||
raise ValueError(
|
||||
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
||||
)
|
||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||
if tool_config:
|
||||
if tool_choice and tool_choice != tool_config.tool_choice:
|
||||
raise ValueError("tool_choice and tool_config.tool_choice must match")
|
||||
if (
|
||||
tool_prompt_format
|
||||
and tool_prompt_format != tool_config.tool_prompt_format
|
||||
):
|
||||
raise ValueError(
|
||||
"tool_prompt_format and tool_config.tool_prompt_format must match"
|
||||
)
|
||||
if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
|
||||
raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
|
||||
else:
|
||||
params = {}
|
||||
if tool_choice:
|
||||
|
|
@ -296,14 +261,9 @@ class InferenceRouter(Inference):
|
|||
pass
|
||||
else:
|
||||
# verify tool_choice is one of the tools
|
||||
tool_names = [
|
||||
t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
|
||||
for t in tools
|
||||
]
|
||||
tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
|
||||
if tool_config.tool_choice not in tool_names:
|
||||
raise ValueError(
|
||||
f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
|
||||
)
|
||||
raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
|
||||
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
|
|
@ -318,25 +278,17 @@ class InferenceRouter(Inference):
|
|||
tool_config=tool_config,
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
prompt_tokens = await self._count_tokens(
|
||||
messages, tool_config.tool_prompt_format
|
||||
)
|
||||
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
|
||||
|
||||
if stream:
|
||||
|
||||
async def stream_generator():
|
||||
completion_text = ""
|
||||
async for chunk in await provider.chat_completion(**params):
|
||||
if (
|
||||
chunk.event.event_type
|
||||
== ChatCompletionResponseEventType.progress
|
||||
):
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
|
||||
if chunk.event.delta.type == "text":
|
||||
completion_text += chunk.event.delta.text
|
||||
if (
|
||||
chunk.event.event_type
|
||||
== ChatCompletionResponseEventType.complete
|
||||
):
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
||||
completion_tokens = await self._count_tokens(
|
||||
[
|
||||
CompletionMessage(
|
||||
|
|
@ -353,11 +305,7 @@ class InferenceRouter(Inference):
|
|||
total_tokens,
|
||||
model,
|
||||
)
|
||||
chunk.metrics = (
|
||||
metrics
|
||||
if chunk.metrics is None
|
||||
else chunk.metrics + metrics
|
||||
)
|
||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||
yield chunk
|
||||
|
||||
return stream_generator()
|
||||
|
|
@ -374,9 +322,7 @@ class InferenceRouter(Inference):
|
|||
total_tokens,
|
||||
model,
|
||||
)
|
||||
response.metrics = (
|
||||
metrics if response.metrics is None else response.metrics + metrics
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
return response
|
||||
|
||||
async def completion(
|
||||
|
|
@ -397,9 +343,7 @@ class InferenceRouter(Inference):
|
|||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.embedding:
|
||||
raise ValueError(
|
||||
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
||||
)
|
||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
|
|
@ -419,11 +363,7 @@ class InferenceRouter(Inference):
|
|||
async for chunk in await provider.completion(**params):
|
||||
if hasattr(chunk, "delta"):
|
||||
completion_text += chunk.delta
|
||||
if (
|
||||
hasattr(chunk, "stop_reason")
|
||||
and chunk.stop_reason
|
||||
and self.telemetry
|
||||
):
|
||||
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
|
||||
completion_tokens = await self._count_tokens(completion_text)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
|
|
@ -432,11 +372,7 @@ class InferenceRouter(Inference):
|
|||
total_tokens,
|
||||
model,
|
||||
)
|
||||
chunk.metrics = (
|
||||
metrics
|
||||
if chunk.metrics is None
|
||||
else chunk.metrics + metrics
|
||||
)
|
||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||
yield chunk
|
||||
|
||||
return stream_generator()
|
||||
|
|
@ -450,9 +386,7 @@ class InferenceRouter(Inference):
|
|||
total_tokens,
|
||||
model,
|
||||
)
|
||||
response.metrics = (
|
||||
metrics if response.metrics is None else response.metrics + metrics
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
return response
|
||||
|
||||
async def embeddings(
|
||||
|
|
@ -468,9 +402,7 @@ class InferenceRouter(Inference):
|
|||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.llm:
|
||||
raise ValueError(
|
||||
f"Model '{model_id}' is an LLM model and does not support embeddings"
|
||||
)
|
||||
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
|
||||
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
||||
model_id=model_id,
|
||||
contents=contents,
|
||||
|
|
@ -504,9 +436,7 @@ class SafetyRouter(Safety):
|
|||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> Shield:
|
||||
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
||||
return await self.routing_table.register_shield(
|
||||
shield_id, provider_shield_id, provider_id, params
|
||||
)
|
||||
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
|
|
@ -607,9 +537,9 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
logger.debug(
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(
|
||||
"insert_into_memory"
|
||||
).insert(documents, vector_db_id, chunk_size_in_tokens)
|
||||
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
||||
documents, vector_db_id, chunk_size_in_tokens
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -642,6 +572,4 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> List[ToolDef]:
|
||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
|
||||
tool_group_id, mcp_endpoint
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue