Merge branch 'refs/heads/main' into preprocessors

This commit is contained in:
ilya-kolchinsky 2025-03-11 20:05:52 +01:00
commit d38aea33c1
37 changed files with 493 additions and 255 deletions

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
from .agents import MetaReferenceAgentsImpl
impl = MetaReferenceAgentsImpl(

View file

@ -12,6 +12,7 @@ import uuid
from typing import AsyncGenerator, List, Optional, Union
from llama_stack.apis.agents import (
Agent,
AgentConfig,
AgentCreateResponse,
Agents,
@ -21,6 +22,8 @@ from llama_stack.apis.agents import (
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
ListAgentSessionsResponse,
ListAgentsResponse,
Session,
Turn,
)
@ -84,7 +87,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id=agent_id,
)
async def get_agent(self, agent_id: str) -> ChatAgent:
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
agent_config = await self.persistence_store.get(
key=f"agent:{agent_id}",
)
@ -120,7 +123,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse:
agent = await self.get_agent(agent_id)
agent = await self._get_agent_impl(agent_id)
session_id = await agent.create_session(session_name)
return AgentSessionCreateResponse(
@ -160,7 +163,7 @@ class MetaReferenceAgentsImpl(Agents):
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
agent = await self._get_agent_impl(request.agent_id)
async for event in agent.create_and_execute_turn(request):
yield event
@ -188,12 +191,12 @@ class MetaReferenceAgentsImpl(Agents):
self,
request: AgentTurnResumeRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
agent = await self._get_agent_impl(request.agent_id)
async for event in agent.resume_turn(request):
yield event
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
agent = await self.get_agent(agent_id)
agent = await self._get_agent_impl(agent_id)
turn = await agent.storage.get_session_turn(session_id, turn_id)
return turn
@ -210,7 +213,7 @@ class MetaReferenceAgentsImpl(Agents):
session_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session:
agent = await self.get_agent(agent_id)
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
@ -232,3 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
# Shutdown hook for the agents implementation; the body is a no-op (`pass`).
async def shutdown(self) -> None:
pass
# Placeholder stub: the body is only `pass`, so awaiting this yields None
# rather than the annotated ListAgentsResponse.
# NOTE(review): presumably added to satisfy the `Agents` base interface --
# confirm, and consider raising NotImplementedError until it is implemented.
async def list_agents(self) -> ListAgentsResponse:
pass
# Placeholder stub: the body is only `pass`, so awaiting this yields None
# rather than the annotated Agent.
# The previous `get_agent` (which returned a ChatAgent) was renamed to
# `_get_agent_impl` in this change; internal callers now use that helper.
# NOTE(review): consider raising NotImplementedError until this is implemented.
async def get_agent(self, agent_id: str) -> Agent:
pass
# Placeholder stub: the body is only `pass`, so awaiting this yields None
# rather than the annotated ListAgentSessionsResponse; `agent_id` is unused.
# NOTE(review): consider raising NotImplementedError until this is implemented.
async def list_agent_sessions(
self,
agent_id: str,
) -> ListAgentSessionsResponse:
pass

View file

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import LocalFSDatasetIOConfig
async def get_provider_impl(
config: LocalFSDatasetIOConfig,
_deps,
_deps: Dict[str, Any],
):
from .datasetio import LocalFSDatasetIOImpl

View file

@ -3,16 +3,16 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import MetaReferenceEvalConfig
async def get_provider_impl(
config: MetaReferenceEvalConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .eval import MetaReferenceEvalImpl

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Union
from typing import Any, Dict, Union
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
async def get_provider_impl(
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
_deps,
_deps: Dict[str, Any],
):
from .inference import MetaReferenceInferenceImpl

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_stack.providers.inline.inference.sentence_transformers.config import (
SentenceTransformersInferenceConfig,
)
@ -11,7 +13,7 @@ from llama_stack.providers.inline.inference.sentence_transformers.config import
async def get_provider_impl(
config: SentenceTransformersInferenceConfig,
_deps,
_deps: Dict[str, Any],
):
from .sentence_transformers import SentenceTransformersInferenceImpl

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from typing import Any, Dict
from .config import VLLMConfig
async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
async def get_provider_impl(config: VLLMConfig, _deps: Dict[str, Any]):
from .vllm import VLLMInferenceImpl
impl = VLLMInferenceImpl(config)

View file

@ -4,9 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import TorchtunePostTrainingConfig
@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig
async def get_provider_impl(
config: TorchtunePostTrainingConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .post_training import TorchtunePostTrainingImpl

View file

@ -43,6 +43,9 @@ class TorchtunePostTrainingImpl:
self.jobs = {}
self.checkpoints_dict = {}
# Shutdown hook for the post-training implementation; the body is a no-op
# (`pass`), so nothing is released or cancelled here.
async def shutdown(self):
pass
async def supervised_fine_tune(
self,
job_uuid: str,

View file

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import CodeScannerConfig
async def get_provider_impl(config: CodeScannerConfig, deps):
async def get_provider_impl(config: CodeScannerConfig, deps: Dict[str, Any]):
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)

View file

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import LlamaGuardConfig
async def get_provider_impl(config: LlamaGuardConfig, deps):
async def get_provider_impl(config: LlamaGuardConfig, deps: Dict[str, Any]):
from .llama_guard import LlamaGuardSafetyImpl
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"

View file

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import PromptGuardConfig # noqa: F401
async def get_provider_impl(config: PromptGuardConfig, deps):
async def get_provider_impl(config: PromptGuardConfig, deps: Dict[str, Any]):
from .prompt_guard import PromptGuardSafetyImpl
impl = PromptGuardSafetyImpl(config, deps)

View file

@ -3,16 +3,16 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import BasicScoringConfig
async def get_provider_impl(
config: BasicScoringConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .scoring import BasicScoringImpl

View file

@ -3,11 +3,11 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import BraintrustScoringConfig
@ -18,7 +18,7 @@ class BraintrustProviderDataValidator(BaseModel):
async def get_provider_impl(
config: BraintrustScoringConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .braintrust import BraintrustScoringImpl

View file

@ -3,16 +3,16 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import LlmAsJudgeScoringConfig
async def get_provider_impl(
config: LlmAsJudgeScoringConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .scoring import LlmAsJudgeScoringImpl

View file

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import CodeInterpreterToolConfig
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps):
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps: Dict[str, Any]):
from .code_interpreter import CodeInterpreterToolRuntimeImpl
impl = CodeInterpreterToolRuntimeImpl(config)

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.datatypes import Api
from .config import ChromaVectorIOConfig
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, Any]):
from llama_stack.providers.remote.vector_io.chroma.chroma import (
ChromaVectorIOAdapter,
)

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.datatypes import Api
from .config import FaissVectorIOConfig
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, Any]):
from .faiss import FaissVectorIOAdapter
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.datatypes import Api
from .config import MilvusVectorIOConfig
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, Any]):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
impl = MilvusVectorIOAdapter(config, deps[Api.inference])

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.datatypes import Api
from .config import SQLiteVectorIOConfig
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, Any]):
from .sqlite_vec import SQLiteVecVectorIOAdapter
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"

View file

@ -24,10 +24,6 @@ MODEL_ENTRIES = [
"accounts/fireworks/models/llama-v3p1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,