Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-10 04:08:31 +00:00)
Commit f262bfd061 ("fix")
Parent: 9b38ae9323
10 changed files with 107 additions and 319 deletions
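Nearly every hunk below follows one pattern: statements that were wrapped to satisfy a short line-length limit are collapsed onto single, longer lines, which is why the diff removes 319 lines while adding only 107. A representative before/after, taken from one of the VectorIORouter hunks below (a plain sketch; indentation approximated):

    # before: wrapped to fit a narrower line limit
    logger.debug(
        f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
    )

    # after: collapsed onto one line that fits a wider limit
    logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")

The shape of the changes is consistent with raising the formatter's line-length setting (for example, from black's default 88 columns to a wider ruff configuration), but the one-word commit message "fix" does not confirm the tooling, so treat that reading as an inference. The changed file paths are not shown in this mirror; the @@ hunk contexts name the classes and functions touched.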
Changed file 1 of 10:

@@ -8,9 +8,9 @@ import time
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 from llama_stack.apis.common.content_types import (
+    URL,
     InterleavedContent,
     InterleavedContentItem,
-    URL,
 )
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import DatasetPurpose, DataSource
@@ -94,9 +94,7 @@ class VectorIORouter(VectorIO):
         provider_id: Optional[str] = None,
         provider_vector_db_id: Optional[str] = None,
     ) -> None:
-        logger.debug(
-            f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
-        )
+        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
         await self.routing_table.register_vector_db(
             vector_db_id,
             embedding_model,
@@ -114,9 +112,7 @@ class VectorIORouter(VectorIO):
         logger.debug(
             f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
         )
-        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
-            vector_db_id, chunks, ttl_seconds
-        )
+        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
 
     async def query_chunks(
         self,
@@ -125,9 +121,7 @@ class VectorIORouter(VectorIO):
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryChunksResponse:
         logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
-        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
-            vector_db_id, query, params
-        )
+        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
 
 
 class InferenceRouter(Inference):
@@ -164,9 +158,7 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
-        await self.routing_table.register_model(
-            model_id, provider_model_id, provider_id, metadata, model_type
-        )
+        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
 
     def _construct_metrics(
         self,
@@ -220,16 +212,11 @@ class InferenceRouter(Inference):
         total_tokens: int,
         model: Model,
     ) -> List[MetricInResponse]:
-        metrics = self._construct_metrics(
-            prompt_tokens, completion_tokens, total_tokens, model
-        )
+        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
         if self.telemetry:
             for metric in metrics:
                 await self.telemetry.log_event(metric)
-        return [
-            MetricInResponse(metric=metric.metric, value=metric.value)
-            for metric in metrics
-        ]
+        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
 
     async def _count_tokens(
         self,
@@ -254,9 +241,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[
-        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
-    ]:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         logger.debug(
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
@@ -266,19 +251,12 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(
-                f"Model '{model_id}' is an embedding model and does not support chat completions"
-            )
+            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         if tool_config:
             if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
-            if (
-                tool_prompt_format
-                and tool_prompt_format != tool_config.tool_prompt_format
-            ):
-                raise ValueError(
-                    "tool_prompt_format and tool_config.tool_prompt_format must match"
-                )
+            if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
+                raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
         else:
             params = {}
             if tool_choice:
@@ -296,14 +274,9 @@ class InferenceRouter(Inference):
                 pass
             else:
                 # verify tool_choice is one of the tools
-                tool_names = [
-                    t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
-                    for t in tools
-                ]
+                tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
                 if tool_config.tool_choice not in tool_names:
-                    raise ValueError(
-                        f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
-                    )
+                    raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
 
         params = dict(
             model_id=model_id,
@@ -318,25 +291,17 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
-        prompt_tokens = await self._count_tokens(
-            messages, tool_config.tool_prompt_format
-        )
+        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
 
         if stream:
 
             async def stream_generator():
                 completion_text = ""
                 async for chunk in await provider.chat_completion(**params):
-                    if (
-                        chunk.event.event_type
-                        == ChatCompletionResponseEventType.progress
-                    ):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
                         if chunk.event.delta.type == "text":
                             completion_text += chunk.event.delta.text
-                    if (
-                        chunk.event.event_type
-                        == ChatCompletionResponseEventType.complete
-                    ):
+                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
                         completion_tokens = await self._count_tokens(
                             [
                                 CompletionMessage(
@@ -353,11 +318,7 @@ class InferenceRouter(Inference):
                             total_tokens,
                             model,
                         )
-                        chunk.metrics = (
-                            metrics
-                            if chunk.metrics is None
-                            else chunk.metrics + metrics
-                        )
+                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                     yield chunk
 
             return stream_generator()
@@ -374,9 +335,7 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-        response.metrics = (
-            metrics if response.metrics is None else response.metrics + metrics
-        )
+        response.metrics = metrics if response.metrics is None else response.metrics + metrics
         return response
 
     async def completion(
@@ -397,9 +356,7 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(
-                f"Model '{model_id}' is an embedding model and does not support chat completions"
-            )
+            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
@@ -419,11 +376,7 @@ class InferenceRouter(Inference):
             async for chunk in await provider.completion(**params):
                 if hasattr(chunk, "delta"):
                     completion_text += chunk.delta
-                if (
-                    hasattr(chunk, "stop_reason")
-                    and chunk.stop_reason
-                    and self.telemetry
-                ):
+                if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
                     completion_tokens = await self._count_tokens(completion_text)
                     total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                     metrics = await self._compute_and_log_token_usage(
@@ -432,11 +385,7 @@ class InferenceRouter(Inference):
                         total_tokens,
                         model,
                     )
-                    chunk.metrics = (
-                        metrics
-                        if chunk.metrics is None
-                        else chunk.metrics + metrics
-                    )
+                    chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
                 yield chunk
 
             return stream_generator()
@@ -450,9 +399,7 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-        response.metrics = (
-            metrics if response.metrics is None else response.metrics + metrics
-        )
+        response.metrics = metrics if response.metrics is None else response.metrics + metrics
         return response
 
     async def embeddings(
@@ -468,9 +415,7 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.llm:
-            raise ValueError(
-                f"Model '{model_id}' is an LLM model and does not support embeddings"
-            )
+            raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
         return await self.routing_table.get_provider_impl(model_id).embeddings(
             model_id=model_id,
             contents=contents,
@@ -504,9 +449,7 @@ class SafetyRouter(Safety):
         params: Optional[Dict[str, Any]] = None,
     ) -> Shield:
         logger.debug(f"SafetyRouter.register_shield: {shield_id}")
-        return await self.routing_table.register_shield(
-            shield_id, provider_shield_id, provider_id, params
-        )
+        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
 
     async def run_shield(
         self,
@@ -603,9 +546,7 @@ class ScoringRouter(Scoring):
         logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(
-                fn_identifier
-            ).score_batch(
+            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
                 dataset_id=dataset_id,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -623,15 +564,11 @@ class ScoringRouter(Scoring):
         input_rows: List[Dict[str, Any]],
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
-        logger.debug(
-            f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions"
-        )
+        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(
-                fn_identifier
-            ).score(
+            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
                 input_rows=input_rows,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -674,9 +611,7 @@ class EvalRouter(Eval):
         scoring_functions: List[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
-        logger.debug(
-            f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows"
-        )
+        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
         return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
@@ -690,9 +625,7 @@ class EvalRouter(Eval):
         job_id: str,
     ) -> Optional[JobStatus]:
         logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_status(
-            benchmark_id, job_id
-        )
+        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
 
     async def job_cancel(
         self,
@@ -746,9 +679,9 @@ class ToolRuntimeRouter(ToolRuntime):
             logger.debug(
                 f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
             )
-            return await self.routing_table.get_provider_impl(
-                "insert_into_memory"
-            ).insert(documents, vector_db_id, chunk_size_in_tokens)
+            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
+                documents, vector_db_id, chunk_size_in_tokens
+            )
 
     def __init__(
         self,
@@ -781,6 +714,4 @@ class ToolRuntimeRouter(ToolRuntime):
         self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
     ) -> List[ToolDef]:
         logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
-        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
-            tool_group_id, mcp_endpoint
-        )
+        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
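Note that the ToolRuntimeRouter.RagToolImpl.insert hunk above is the one spot where the new style keeps a wrapped call: even at the wider limit the chained get_provider_impl("insert_into_memory").insert(...) call would not fit on one line, so the arguments move inside the trailing parentheses instead.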
Changed file 2 of 10:

@@ -105,9 +105,7 @@ class CommonRoutingTableImpl(RoutingTable):
         self.dist_registry = dist_registry
 
     async def initialize(self) -> None:
-        async def add_objects(
-            objs: List[RoutableObjectWithProvider], provider_id: str, cls
-        ) -> None:
+        async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
             for obj in objs:
                 if cls is None:
                     obj.provider_id = provider_id
@@ -142,9 +140,7 @@ class CommonRoutingTableImpl(RoutingTable):
         for p in self.impls_by_provider_id.values():
             await p.shutdown()
 
-    def get_provider_impl(
-        self, routing_key: str, provider_id: Optional[str] = None
-    ) -> Any:
+    def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
         def apiname_object():
             if isinstance(self, ModelsRoutingTable):
                 return ("Inference", "model")
@@ -182,9 +178,7 @@ class CommonRoutingTableImpl(RoutingTable):
 
         raise ValueError(f"Provider not found for `{routing_key}`")
 
-    async def get_object_by_identifier(
-        self, type: str, identifier: str
-    ) -> Optional[RoutableObjectWithProvider]:
+    async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
         # Get from disk registry
         obj = await self.dist_registry.get(type, identifier)
         if not obj:
@@ -194,13 +188,9 @@ class CommonRoutingTableImpl(RoutingTable):
 
     async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
         await self.dist_registry.delete(obj.type, obj.identifier)
-        await unregister_object_from_provider(
-            obj, self.impls_by_provider_id[obj.provider_id]
-        )
+        await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
 
-    async def register_object(
-        self, obj: RoutableObjectWithProvider
-    ) -> RoutableObjectWithProvider:
+    async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
         # if provider_id is not specified, pick an arbitrary one from existing entries
         if not obj.provider_id and len(self.impls_by_provider_id) > 0:
             obj.provider_id = list(self.impls_by_provider_id.keys())[0]
@@ -255,9 +245,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         if model_type is None:
             model_type = ModelType.llm
         if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
-            raise ValueError(
-                "Embedding model must have an embedding dimension in its metadata"
-            )
+            raise ValueError("Embedding model must have an embedding dimension in its metadata")
         model = Model(
             identifier=model_id,
             provider_resource_id=provider_model_id,
@@ -277,9 +265,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
 
 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
     async def list_shields(self) -> ListShieldsResponse:
-        return ListShieldsResponse(
-            data=await self.get_all_with_type(ResourceType.shield.value)
-        )
+        return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
 
     async def get_shield(self, identifier: str) -> Optional[Shield]:
         return await self.get_object_by_identifier("shield", identifier)
@@ -338,18 +324,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
                 f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
             )
         else:
-            raise ValueError(
-                "No provider available. Please configure a vector_io provider."
-            )
+            raise ValueError("No provider available. Please configure a vector_io provider.")
         model = await self.get_object_by_identifier("model", embedding_model)
         if model is None:
             raise ValueError(f"Model {embedding_model} not found")
         if model.model_type != ModelType.embedding:
             raise ValueError(f"Model {embedding_model} is not an embedding model")
         if "embedding_dimension" not in model.metadata:
-            raise ValueError(
-                f"Model {embedding_model} does not have an embedding dimension"
-            )
+            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
         vector_db_data = {
             "identifier": vector_db_id,
             "type": ResourceType.vector_db.value,
@@ -371,9 +353,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
 
 class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
     async def list_datasets(self) -> ListDatasetsResponse:
-        return ListDatasetsResponse(
-            data=await self.get_all_with_type(ResourceType.dataset.value)
-        )
+        return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
 
     async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
         return await self.get_object_by_identifier("dataset", dataset_id)
@@ -426,9 +406,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
 
 class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
     async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
-        return ListScoringFunctionsResponse(
-            data=await self.get_all_with_type(ResourceType.scoring_function.value)
-        )
+        return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
 
     async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
         return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
@@ -525,12 +503,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         args: Optional[Dict[str, Any]] = None,
     ) -> None:
         tools = []
-        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
-            toolgroup_id, mcp_endpoint
-        )
-        tool_host = (
-            ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
-        )
+        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
+        tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
 
         for tool_def in tool_defs:
             tools.append(
Changed file 3 of 10:

@@ -230,9 +230,7 @@ def run_evaluation_3():
             output_res[scoring_fn] = []
         output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
 
-        progress_text_container.write(
-            f"Expand to see current processed result ({i + 1} / {len(rows)})"
-        )
+        progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
         results_container.json(eval_res, expanded=2)
 
     progress_bar.progress(1.0, text="Evaluation complete!")
Changed file 4 of 10:

@@ -161,9 +161,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
 
         new_rows_df = pandas.DataFrame(rows)
         new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
-        dataset_impl.df = pandas.concat(
-            [dataset_impl.df, new_rows_df], ignore_index=True
-        )
+        dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
 
         url = str(dataset_info.dataset_def.url.uri)
         parsed_url = urlparse(url)
@@ -178,12 +176,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
                 raise ValueError("Data URL must be a base64-encoded CSV")
 
             csv_buffer = dataset_impl.df.to_csv(index=False)
-            base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode(
-                "utf-8"
-            )
-            dataset_info.dataset_def.url = URL(
-                uri=f"data:text/csv;base64,{base64_content}"
-            )
+            base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
+            dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
         else:
             raise ValueError(
                 f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
Changed file 5 of 10:

@@ -89,16 +89,10 @@ class MetaReferenceEvalImpl(
         dataset_id = task_def.dataset_id
         scoring_functions = task_def.scoring_functions
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(
-            dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
-        )
+        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
-            rows_in_page=(
-                -1
-                if benchmark_config.num_examples is None
-                else benchmark_config.num_examples
-            ),
+            rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
         )
         res = await self.evaluate_rows(
             benchmark_id=benchmark_id,
@@ -124,14 +118,10 @@ class MetaReferenceEvalImpl(
         for i, x in tqdm(enumerate(input_rows)):
             assert ColumnName.chat_completion_input.value in x, "Invalid input row"
             input_messages = json.loads(x[ColumnName.chat_completion_input.value])
-            input_messages = [
-                UserMessage(**x) for x in input_messages if x["role"] == "user"
-            ]
+            input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
 
             # NOTE: only single-turn agent generation is supported. Create a new session for each input row
-            session_create_response = await self.agents_api.create_agent_session(
-                agent_id, f"session-{i}"
-            )
+            session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
             session_id = session_create_response.session_id
 
             turn_request = dict(
@@ -140,12 +130,7 @@ class MetaReferenceEvalImpl(
                 messages=input_messages,
                 stream=True,
             )
-            turn_response = [
-                chunk
-                async for chunk in await self.agents_api.create_agent_turn(
-                    **turn_request
-                )
-            ]
+            turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
             final_event = turn_response[-1].event.payload
 
             # check if there's a memory retrieval step and extract the context
@@ -154,14 +139,10 @@ class MetaReferenceEvalImpl(
                 if step.step_type == StepType.tool_execution.value:
                     for tool_response in step.tool_responses:
                         if tool_response.tool_name == MEMORY_QUERY_TOOL:
-                            memory_rag_context = " ".join(
-                                x.text for x in tool_response.content
-                            )
+                            memory_rag_context = " ".join(x.text for x in tool_response.content)
 
             agent_generation = {}
-            agent_generation[ColumnName.generated_answer.value] = (
-                final_event.turn.output_message.content
-            )
+            agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
             if memory_rag_context:
                 agent_generation[ColumnName.context.value] = memory_rag_context
 
@@ -173,9 +154,7 @@ class MetaReferenceEvalImpl(
         self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
     ) -> List[Dict[str, Any]]:
         candidate = benchmark_config.eval_candidate
-        assert (
-            candidate.sampling_params.max_tokens is not None
-        ), "SamplingParams.max_tokens must be provided"
+        assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
 
         generations = []
         for x in tqdm(input_rows):
@@ -186,39 +165,21 @@ class MetaReferenceEvalImpl(
                     content=input_content,
                     sampling_params=candidate.sampling_params,
                 )
-                generations.append(
-                    {
-                        ColumnName.generated_answer.value: response.completion_message.content
-                    }
-                )
+                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
             elif ColumnName.chat_completion_input.value in x:
-                chat_completion_input_json = json.loads(
-                    x[ColumnName.chat_completion_input.value]
-                )
-                input_messages = [
-                    UserMessage(**x)
-                    for x in chat_completion_input_json
-                    if x["role"] == "user"
-                ]
+                chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
+                input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
                 messages = []
                 if candidate.system_message:
                     messages.append(candidate.system_message)
-                messages += [
-                    SystemMessage(**x)
-                    for x in chat_completion_input_json
-                    if x["role"] == "system"
-                ]
+                messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
                 messages += input_messages
                 response = await self.inference_api.chat_completion(
                     model_id=candidate.model,
                     messages=messages,
                     sampling_params=candidate.sampling_params,
                 )
-                generations.append(
-                    {
-                        ColumnName.generated_answer.value: response.completion_message.content
-                    }
-                )
+                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
             else:
                 raise ValueError("Invalid input row")
 
@@ -241,8 +202,7 @@ class MetaReferenceEvalImpl(
 
         # scoring with generated_answer
         score_input_rows = [
-            input_r | generated_r
-            for input_r, generated_r in zip(input_rows, generations, strict=False)
+            input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
         ]
 
         if benchmark_config.scoring_params is not None:
@@ -251,9 +211,7 @@ class MetaReferenceEvalImpl(
                 for scoring_fn_id in scoring_functions
             }
         else:
-            scoring_functions_dict = {
-                scoring_fn_id: None for scoring_fn_id in scoring_functions
-            }
+            scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions}
 
         score_response = await self.scoring_api.score(
             input_rows=score_input_rows, scoring_functions=scoring_functions_dict
Changed file 6 of 10:

@@ -17,7 +17,8 @@ import torch
 from torch import nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
-from torchtune import modules, training, utils as torchtune_utils
+from torchtune import modules, training
+from torchtune import utils as torchtune_utils
 from torchtune.data import padded_collate_sft
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.modules.peft import (
@@ -88,9 +89,7 @@ class LoraFinetuningSingleDevice:
         self.job_uuid = job_uuid
         self.training_config = training_config
         if not isinstance(algorithm_config, LoraFinetuningConfig):
-            raise ValueError(
-                "You need to speicifc LoraFinetuningConfig for LoRA finetuning"
-            )
+            raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device()
         self._dtype = training.get_dtype(training_config.dtype, device=self._device)
@@ -99,10 +98,7 @@ class LoraFinetuningSingleDevice:
         def model_checkpoint_dir(model) -> str:
             checkpoint_dir = Path(model_local_dir(model.descriptor()))
 
-            paths = [
-                Path(checkpoint_dir / f"consolidated.{ext}")
-                for ext in ["pth", "00.pth"]
-            ]
+            paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
             if not any(p.exists() for p in paths):
                 checkpoint_dir = checkpoint_dir / "original"
 
@@ -117,9 +113,7 @@ class LoraFinetuningSingleDevice:
         else:
             model = resolve_model(self.model_id)
             if model is None:
-                raise ValueError(
-                    f"{self.model_id} not found. Your model id should be in the llama models SKU list"
-                )
+                raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
             self.checkpoint_dir = model_checkpoint_dir(model)
 
         self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
@@ -191,9 +185,7 @@ class LoraFinetuningSingleDevice:
         self._tokenizer = await self._setup_tokenizer()
         log.info("Tokenizer is initialized.")
 
-        self._optimizer = await self._setup_optimizer(
-            optimizer_config=self.training_config.optimizer_config
-        )
+        self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
         log.info("Optimizer is initialized.")
 
         self._loss_fn = CEWithChunkedOutputLoss()
@@ -221,13 +213,8 @@ class LoraFinetuningSingleDevice:
         # by the dataloader and the max_steps_per_epoch param set by the user and is used
         # for logging and tracking training state. This should be computed after the dataloader
         # has been setup
-        self._steps_per_epoch = (
-            len(self._training_dataloader) // self._gradient_accumulation_steps
-        )
-        if (
-            self.max_steps_per_epoch is not None
-            and self.max_steps_per_epoch < self._steps_per_epoch
-        ):
+        self._steps_per_epoch = len(self._training_dataloader) // self._gradient_accumulation_steps
+        if self.max_steps_per_epoch is not None and self.max_steps_per_epoch < self._steps_per_epoch:
             self._steps_per_epoch = self.max_steps_per_epoch
         self.global_step = self.epochs_run * self._steps_per_epoch
 
@@ -241,9 +228,7 @@ class LoraFinetuningSingleDevice:
         log.info("Learning rate scheduler is initialized.")
 
         # Used to ignore labels for loss computation
-        self.ignore_labels_cache = torch.full(
-            (self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
-        )
+        self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)
 
     def _log_memory_stats(self):
         # torchtune raises: "Logging memory stats is not supported on CPU devices"; do nothing
@@ -284,13 +269,9 @@ class LoraFinetuningSingleDevice:
         set_trainable_params(model, self.adapter_params)
 
         if enable_activation_checkpointing:
-            training.set_activation_checkpointing(
-                model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
-            )
+            training.set_activation_checkpointing(model, auto_wrap_policy={modules.TransformerSelfAttentionLayer})
 
-        base_missing, base_unexpected = model.load_state_dict(
-            base_model_state_dict, strict=False
-        )
+        base_missing, base_unexpected = model.load_state_dict(base_model_state_dict, strict=False)
 
         # This is for any adapters that need to be initialized after base weights
         # have been loaded (e.g. DoRA).
@@ -299,9 +280,7 @@ class LoraFinetuningSingleDevice:
             if hasattr(m, "initialize_dora_magnitude"):
                 m.initialize_dora_magnitude()
         if lora_weights_state_dict:
-            lora_missing, lora_unexpected = model.load_state_dict(
-                lora_weights_state_dict, strict=False
-            )
+            lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False)
         else:
             lora_missing, lora_unexpected = None, None
         validate_missing_and_unexpected_for_lora(
@@ -315,14 +294,10 @@ class LoraFinetuningSingleDevice:
         )
 
         # Validate model adapter params were loaded in with the expected dtype
-        training.validate_expected_param_dtype(
-            self.adapter_params.items(), dtype=self._dtype
-        )
+        training.validate_expected_param_dtype(self.adapter_params.items(), dtype=self._dtype)
 
         # activation offloading
-        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
-            model, enable_activation_offloading
-        )
+        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)
 
         self._log_memory_stats()
 
@@ -458,9 +433,7 @@ class LoraFinetuningSingleDevice:
         # Shift labels to compute loss
         # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
         # But this way we dont need to slice the logits. We just add an ignore index to labels.
-        labels = torch.hstack(
-            (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
-        )
+        labels = torch.hstack((labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]))
         if not isinstance(logits, list):
             labels = labels.reshape(-1)
             logits = logits.reshape(-1, logits.size(-1))
@@ -489,9 +462,7 @@ class LoraFinetuningSingleDevice:
         for curr_epoch in range(self.epochs_run, self.total_epochs):
             # Update the sampler to ensure data is correctly shuffled across epochs
             # in case shuffle is True
-            metric_logger = DiskLogger(
-                log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log"
-            )
+            metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log")
             self._training_sampler.set_epoch(curr_epoch)
             loss_to_log = 0.0
 
@@ -499,8 +470,7 @@ class LoraFinetuningSingleDevice:
             for idx, batch in enumerate(self._training_dataloader):
                 if (
                     self.max_steps_per_epoch is not None
-                    and (idx // self._gradient_accumulation_steps)
-                    == self.max_steps_per_epoch
+                    and (idx // self._gradient_accumulation_steps) == self.max_steps_per_epoch
                 ):
                     break
 
@@ -508,9 +478,7 @@ class LoraFinetuningSingleDevice:
 
                 # Calculate the number of unmasked tokens in the current batch
                 # and increment the total number of tokens seen in the step
-                current_num_tokens = (
-                    batch["labels"] != self._loss_fn.ignore_index
-                ).sum()
+                current_num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
                 num_tokens += current_num_tokens
 
                 # Loss is normalized by default so we multiply by the number of tokens
@@ -535,9 +503,7 @@ class LoraFinetuningSingleDevice:
                     loss_to_log = running_loss.item() / num_tokens
 
                     pbar.update(1)
-                    pbar.set_description(
-                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
-                    )
+                    pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")
 
                     time_per_step = time.perf_counter() - t0
                     log_dict = {
Changed file 7 of 10:

@@ -64,15 +64,11 @@ class BasicScoringImpl(
 
     async def list_scoring_functions(self) -> List[ScoringFn]:
         scoring_fn_defs_list = [
-            fn_def
-            for impl in self.scoring_fn_id_impls.values()
-            for fn_def in impl.get_supported_scoring_fn_defs()
+            fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
         ]
 
         for f in scoring_fn_defs_list:
-            assert f.identifier.startswith(
-                "basic"
-            ), "All basic scoring fn must have identifier prefixed with 'basic'! "
+            assert f.identifier.startswith("basic"), "All basic scoring fn must have identifier prefixed with 'basic'! "
 
         return scoring_fn_defs_list
 
@@ -86,9 +82,7 @@ class BasicScoringImpl(
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(
-            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
-        )
+        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
 
         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
@@ -118,12 +112,8 @@ class BasicScoringImpl(
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
             scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
-            score_results = await scoring_fn.score(
-                input_rows, scoring_fn_id, scoring_fn_params
-            )
-            agg_results = await scoring_fn.aggregate(
-                score_results, scoring_fn_id, scoring_fn_params
-            )
+            score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
+            agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,
Changed file 8 of 10:

@@ -122,12 +122,10 @@ class BraintrustScoringImpl(
         self.datasets_api = datasets_api
 
         self.braintrust_evaluators = {
-            entry.identifier: entry.evaluator
-            for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
+            entry.identifier: entry.evaluator for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
         }
         self.supported_fn_defs_registry = {
-            entry.identifier: entry.fn_def
-            for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
+            entry.identifier: entry.fn_def for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
         }
 
     async def initialize(self) -> None: ...
@@ -137,16 +135,14 @@ class BraintrustScoringImpl(
     async def list_scoring_functions(self) -> List[ScoringFn]:
         scoring_fn_defs_list = list(self.supported_fn_defs_registry.values())
         for f in scoring_fn_defs_list:
-            assert f.identifier.startswith(
-                "braintrust"
-            ), "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
+            assert f.identifier.startswith("braintrust"), (
+                "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "
+            )
 
         return scoring_fn_defs_list
 
     async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
-        raise NotImplementedError(
-            "Registering scoring function not allowed for braintrust provider"
-        )
+        raise NotImplementedError("Registering scoring function not allowed for braintrust provider")
 
     async def set_api_key(self) -> None:
         # api key is in the request headers
@@ -169,17 +165,13 @@ class BraintrustScoringImpl(
         await self.set_api_key()
 
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(
-            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
-        )
+        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
 
         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
             rows_in_page=-1,
         )
-        res = await self.score(
-            input_rows=all_rows.rows, scoring_functions=scoring_functions
-        )
+        res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions)
         if save_results_dataset:
             # TODO: persist and register dataset on to server for reading
             # self.datasets_api.register_dataset()
@@ -220,13 +212,8 @@ class BraintrustScoringImpl(
             if scoring_fn_id not in self.supported_fn_defs_registry:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
 
-            score_results = [
-                await self.score_row(input_row, scoring_fn_id)
-                for input_row in input_rows
-            ]
-            aggregation_functions = self.supported_fn_defs_registry[
-                scoring_fn_id
-            ].params.aggregation_functions
+            score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in input_rows]
+            aggregation_functions = self.supported_fn_defs_registry[scoring_fn_id].params.aggregation_functions
 
             # override scoring_fn params if provided
             if scoring_functions[scoring_fn_id] is not None:
Changed file 9 of 10:

@@ -54,9 +54,9 @@ class LlmAsJudgeScoringImpl(
         scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs()
 
         for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs():
-            assert f.identifier.startswith(
-                "llm-as-judge"
-            ), "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
+            assert f.identifier.startswith("llm-as-judge"), (
+                "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
+            )
 
         return scoring_fn_defs_list
 
@@ -70,9 +70,7 @@ class LlmAsJudgeScoringImpl(
         save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(
-            dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
-        )
+        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
 
         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
@@ -100,12 +98,8 @@ class LlmAsJudgeScoringImpl(
         for scoring_fn_id in scoring_functions.keys():
             scoring_fn = self.llm_as_judge_fn
             scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
-            score_results = await scoring_fn.score(
-                input_rows, scoring_fn_id, scoring_fn_params
-            )
-            agg_results = await scoring_fn.aggregate(
-                score_results, scoring_fn_id, scoring_fn_params
-            )
+            score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
+            agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,
Changed file 10 of 10:

@@ -104,13 +104,9 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         new_dataset = hf_datasets.Dataset.from_list(rows)
 
         # Concatenate the new rows with existing dataset
-        updated_dataset = hf_datasets.concatenate_datasets(
-            [loaded_dataset, new_dataset]
-        )
+        updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset])
 
         if dataset_def.metadata.get("path", None):
             updated_dataset.push_to_hub(dataset_def.metadata["path"])
         else:
-            raise NotImplementedError(
-                "Uploading to URL-based datasets is not supported yet"
-            )
+            raise NotImplementedError("Uploading to URL-based datasets is not supported yet")
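Every hunk on this page is behavior-preserving: long wrapped statements are collapsed, one import block is reordered (URL moves to the top of its group), and one combined torchtune import is split into two statements. If the sweep was produced by tooling rather than by hand, commands along the lines of ruff format (for the line-length collapses) and ruff check --fix (for the import reordering) would reproduce this kind of change, though the exact tool and settings are an assumption; the page does not name them.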