From be3c5c034d2f94d1816bb719e3776f13cbc4896a Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 17 Oct 2024 17:28:17 -0700 Subject: [PATCH 01/10] [bugfix] fix case for agent when memory bank registered without specifying provider_id (#264) * fix case where memory bank is registered without provider_id * memory test * agents unit test --- llama_stack/apis/memory_banks/client.py | 15 +++ .../distribution/routers/routing_tables.py | 10 +- .../tests/agents/provider_config_example.yaml | 2 +- .../providers/tests/agents/test_agents.py | 101 ++++++++++++++++++ .../tests/memory/provider_config_example.yaml | 4 +- .../providers/tests/memory/test_memory.py | 24 +++++ 6 files changed, 151 insertions(+), 5 deletions(-) diff --git a/llama_stack/apis/memory_banks/client.py b/llama_stack/apis/memory_banks/client.py index 588a93fe2..69be35d02 100644 --- a/llama_stack/apis/memory_banks/client.py +++ b/llama_stack/apis/memory_banks/client.py @@ -92,6 +92,21 @@ async def run_main(host: str, port: int, stream: bool): response = await client.list_memory_banks() cprint(f"list_memory_banks response={response}", "green") + # register memory bank for the first time + response = await client.register_memory_bank( + VectorMemoryBankDef( + identifier="test_bank2", + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ) + ) + cprint(f"register_memory_bank response={response}", "blue") + + # list again after registering + response = await client.list_memory_banks() + cprint(f"list_memory_banks response={response}", "green") + def main(host: str, port: int, stream: bool = True): asyncio.run(run_main(host, port, stream)) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 17755f0e4..ede30aea1 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -110,10 +110,16 @@ class CommonRoutingTableImpl(RoutingTable): async def register_object(self, obj: RoutableObjectWithProvider): entries = self.registry.get(obj.identifier, []) for entry in entries: - if entry.provider_id == obj.provider_id: - print(f"`{obj.identifier}` already registered with `{obj.provider_id}`") + if entry.provider_id == obj.provider_id or not obj.provider_id: + print( + f"`{obj.identifier}` already registered with `{entry.provider_id}`" + ) return + # if provider_id is not specified, we'll pick an arbitrary one from existing entries + if not obj.provider_id and len(self.impls_by_provider_id) > 0: + obj.provider_id = list(self.impls_by_provider_id.keys())[0] + if obj.provider_id not in self.impls_by_provider_id: raise ValueError(f"Provider `{obj.provider_id}` not found") diff --git a/llama_stack/providers/tests/agents/provider_config_example.yaml b/llama_stack/providers/tests/agents/provider_config_example.yaml index 5b643590c..58f05e29a 100644 --- a/llama_stack/providers/tests/agents/provider_config_example.yaml +++ b/llama_stack/providers/tests/agents/provider_config_example.yaml @@ -31,4 +31,4 @@ providers: persistence_store: namespace: null type: sqlite - db_path: /Users/ashwin/.llama/runtime/kvstore.db + db_path: ~/.llama/runtime/kvstore.db diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index edcc6adea..6774d3f1f 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -64,6 +64,24 @@ def search_query_messages(): ] +@pytest.fixture +def 
attachment_message(): + return [ + UserMessage( + content="I am attaching some documentation for Torchtune. Help me answer questions I will ask next.", + ), + ] + + +@pytest.fixture +def query_attachment_messages(): + return [ + UserMessage( + content="What are the top 5 topics that were explained? Only list succinct bullet points." + ), + ] + + @pytest.mark.asyncio async def test_create_agent_turn(agents_settings, sample_messages): agents_impl = agents_settings["impl"] @@ -123,6 +141,89 @@ async def test_create_agent_turn(agents_settings, sample_messages): assert len(final_event.turn.output_message.content) > 0 +@pytest.mark.asyncio +async def test_rag_agent_as_attachments( + agents_settings, attachment_message, query_attachment_messages +): + urls = [ + "memory_optimizations.rst", + "chat.rst", + "llama3.rst", + "datasets.rst", + "qat_finetune.rst", + "lora_finetune.rst", + ] + + attachments = [ + Attachment( + content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", + mime_type="text/plain", + ) + for i, url in enumerate(urls) + ] + + agents_impl = agents_settings["impl"] + + agent_config = AgentConfig( + model=agents_settings["common_params"]["model"], + instructions=agents_settings["common_params"]["instructions"], + enable_session_persistence=True, + sampling_params=SamplingParams(temperature=0.7, top_p=0.95), + input_shields=[], + output_shields=[], + tools=[ + MemoryToolDefinition( + memory_bank_configs=[], + query_generator_config={ + "type": "default", + "sep": " ", + }, + max_tokens_in_context=4096, + max_chunks=10, + ), + ], + max_infer_iters=5, + ) + + create_response = await agents_impl.create_agent(agent_config) + agent_id = create_response.agent_id + + # Create a session + session_create_response = await agents_impl.create_agent_session( + agent_id, "Test Session" + ) + session_id = session_create_response.session_id + + # Create and execute a turn + turn_request = dict( + agent_id=agent_id, + session_id=session_id, + messages=attachment_message, + attachments=attachments, + stream=True, + ) + + turn_response = [ + chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + ] + + assert len(turn_response) > 0 + + # Create a second turn querying the agent + turn_request = dict( + agent_id=agent_id, + session_id=session_id, + messages=query_attachment_messages, + stream=True, + ) + + turn_response = [ + chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + ] + + assert len(turn_response) > 0 + + @pytest.mark.asyncio async def test_create_agent_turn_with_brave_search( agents_settings, search_query_messages diff --git a/llama_stack/providers/tests/memory/provider_config_example.yaml b/llama_stack/providers/tests/memory/provider_config_example.yaml index cac1adde5..5b5440f8d 100644 --- a/llama_stack/providers/tests/memory/provider_config_example.yaml +++ b/llama_stack/providers/tests/memory/provider_config_example.yaml @@ -2,8 +2,8 @@ providers: - provider_id: test-faiss provider_type: meta-reference config: {} - - provider_id: test-chroma - provider_type: remote::chroma + - provider_id: test-chromadb + provider_type: remote::chromadb config: host: localhost port: 6001 diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index c5ebdf9c7..d92feaba8 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -89,6 +89,30 @@ async def test_banks_list(memory_settings): assert len(response) 
== 0 +@pytest.mark.asyncio +async def test_banks_register(memory_settings): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + banks_impl = memory_settings["memory_banks_impl"] + bank = VectorMemoryBankDef( + identifier="test_bank_no_provider", + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ) + + await banks_impl.register_memory_bank(bank) + response = await banks_impl.list_memory_banks() + assert isinstance(response, list) + assert len(response) == 1 + + # register same memory bank with same id again will fail + await banks_impl.register_memory_bank(bank) + response = await banks_impl.list_memory_banks() + assert isinstance(response, list) + assert len(response) == 1 + + @pytest.mark.asyncio async def test_query_documents(memory_settings, sample_documents): memory_impl = memory_settings["memory_impl"] From 33afd34e6f557d2a9aae762a590cc75c89bcc029 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 18 Oct 2024 12:51:10 -0700 Subject: [PATCH 02/10] Add an option to not use elastic agents for meta-reference inference (#269) --- .../impls/meta_reference/inference/config.py | 7 +++- .../meta_reference/inference/inference.py | 34 +++++++++++++++---- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/llama_stack/providers/impls/meta_reference/inference/config.py b/llama_stack/providers/impls/meta_reference/inference/config.py index 901a8c7fb..4e1161ced 100644 --- a/llama_stack/providers/impls/meta_reference/inference/config.py +++ b/llama_stack/providers/impls/meta_reference/inference/config.py @@ -17,13 +17,18 @@ from llama_stack.providers.utils.inference import supported_inference_models class MetaReferenceInferenceConfig(BaseModel): model: str = Field( - default="Llama3.1-8B-Instruct", + default="Llama3.2-3B-Instruct", description="Model descriptor from `llama model list`", ) torch_seed: Optional[int] = None max_seq_len: int = 4096 max_batch_size: int = 1 + # when this is False, we assume that the distributed process group is setup by someone + # outside of this code (e.g., when run inside `torchrun`). that is useful for clients + # (including our testing code) who might be using llama-stack as a library. + create_distributed_process_group: bool = True + @field_validator("model") @classmethod def validate_model(cls, model: str) -> str: diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py index 6696762c9..7edc279d0 100644 --- a/llama_stack/providers/impls/meta_reference/inference/inference.py +++ b/llama_stack/providers/impls/meta_reference/inference/inference.py @@ -18,6 +18,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( ) from .config import MetaReferenceInferenceConfig +from .generation import Llama from .model_parallel import LlamaModelParallelGenerator # there's a single model parallel process running serving the model. 
for now, @@ -36,8 +37,11 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): async def initialize(self) -> None: print(f"Loading model `{self.model.descriptor()}`") - self.generator = LlamaModelParallelGenerator(self.config) - self.generator.start() + if self.config.create_distributed_process_group: + self.generator = LlamaModelParallelGenerator(self.config) + self.generator.start() + else: + self.generator = Llama.build(self.config) async def register_model(self, model: ModelDef) -> None: raise ValueError("Dynamic model registration is not supported") @@ -51,7 +55,8 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): ] async def shutdown(self) -> None: - self.generator.stop() + if self.config.create_distributed_process_group: + self.generator.stop() def completion( self, @@ -99,8 +104,9 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): f"Model mismatch: {request.model} != {self.model.descriptor()}" ) - if SEMAPHORE.locked(): - raise RuntimeError("Only one concurrent request is supported") + if self.config.create_distributed_process_group: + if SEMAPHORE.locked(): + raise RuntimeError("Only one concurrent request is supported") if request.stream: return self._stream_chat_completion(request) @@ -110,7 +116,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): async def _nonstream_chat_completion( self, request: ChatCompletionRequest ) -> ChatCompletionResponse: - async with SEMAPHORE: + def impl(): messages = chat_completion_request_to_messages(request) tokens = [] @@ -154,10 +160,16 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): logprobs=logprobs if request.logprobs else None, ) + if self.config.create_distributed_process_group: + async with SEMAPHORE: + return impl() + else: + return impl() + async def _stream_chat_completion( self, request: ChatCompletionRequest ) -> AsyncGenerator: - async with SEMAPHORE: + def impl(): messages = chat_completion_request_to_messages(request) yield ChatCompletionResponseStreamChunk( @@ -272,6 +284,14 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): ) ) + if self.config.create_distributed_process_group: + async with SEMAPHORE: + for x in impl(): + yield x + else: + for x in impl(): + yield x + async def embeddings( self, model: str, From 71a905e93f06b7779d37755d0c8831513f54cb8f Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 18 Oct 2024 14:28:06 -0700 Subject: [PATCH 03/10] Allow overridding checkpoint_dir via config --- .../impls/meta_reference/inference/config.py | 4 ++++ .../meta_reference/inference/generation.py | 21 +++++++++++-------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/llama_stack/providers/impls/meta_reference/inference/config.py b/llama_stack/providers/impls/meta_reference/inference/config.py index 4e1161ced..48cba645b 100644 --- a/llama_stack/providers/impls/meta_reference/inference/config.py +++ b/llama_stack/providers/impls/meta_reference/inference/config.py @@ -29,6 +29,10 @@ class MetaReferenceInferenceConfig(BaseModel): # (including our testing code) who might be using llama-stack as a library. 
create_distributed_process_group: bool = True + # By default, the implementation will look at ~/.llama/checkpoints/ but you + # can override by specifying the directory explicitly + checkpoint_dir: Optional[str] = None + @field_validator("model") @classmethod def validate_model(cls, model: str) -> str: diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py index 9037b9acd..20a8addc7 100644 --- a/llama_stack/providers/impls/meta_reference/inference/generation.py +++ b/llama_stack/providers/impls/meta_reference/inference/generation.py @@ -98,7 +98,10 @@ class Llama: sys.stdout = open(os.devnull, "w") start_time = time.time() - ckpt_dir = model_checkpoint_dir(model) + if config.checkpoint_dir: + ckpt_dir = config.checkpoint_dir + else: + ckpt_dir = model_checkpoint_dir(model) checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" @@ -119,9 +122,7 @@ class Llama: **params, ) - tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model") - tokenizer = Tokenizer(model_path=tokenizer_path) - + tokenizer = Tokenizer.get_instance() assert ( model_args.vocab_size == tokenizer.n_words ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}" @@ -170,14 +171,16 @@ class Llama: logprobs: bool = False, echo: bool = False, include_stop_token: bool = False, + print_input_tokens: bool = False, ) -> Generator: params = self.model.params - # input_tokens = [ - # self.formatter.vision_token if t == 128256 else t - # for t in model_input.tokens - # ] - # cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red") + if print_input_tokens: + input_tokens = [ + self.formatter.vision_token if t == 128256 else t + for t in model_input.tokens + ] + cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red") prompt_tokens = [model_input.tokens] bsz = 1 From 95a96afe34136f060591df2509004cf9c98701b4 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 18 Oct 2024 14:41:38 -0700 Subject: [PATCH 04/10] Small rename --- llama_stack/distribution/resolver.py | 2 +- llama_stack/distribution/server/server.py | 4 ++-- llama_stack/providers/tests/resolver.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index a05e08cd7..78d76e977 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -55,7 +55,7 @@ class ProviderWithSpec(Provider): # TODO: this code is not very straightforward to follow and needs one more round of refactoring -async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]: +async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]: """ Does two things: - flatmaps, sorts and resolves the providers in dependency order diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index eba89e393..6154432b6 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -37,7 +37,7 @@ from llama_stack.providers.utils.telemetry.tracing import ( from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.request_headers import set_request_provider_data -from llama_stack.distribution.resolver import resolve_impls_with_routing +from llama_stack.distribution.resolver import resolve_impls from .endpoints import 
get_all_api_endpoints @@ -276,7 +276,7 @@ def main( app = FastAPI() - impls = asyncio.run(resolve_impls_with_routing(config)) + impls = asyncio.run(resolve_impls(config)) if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index fabb245e7..de672b6dc 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -14,7 +14,7 @@ import yaml from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.request_headers import set_request_provider_data -from llama_stack.distribution.resolver import resolve_impls_with_routing +from llama_stack.distribution.resolver import resolve_impls async def resolve_impls_for_test(api: Api, deps: List[Api] = None): @@ -36,7 +36,7 @@ async def resolve_impls_for_test(api: Api, deps: List[Api] = None): providers=chosen, ) run_config = parse_and_maybe_upgrade_config(run_config) - impls = await resolve_impls_with_routing(run_config) + impls = await resolve_impls(run_config) if "provider_data" in config_dict: provider_id = chosen[api.value][0].provider_id From 2089427d60be9f17d8de9cadd1e6c0c6cef253fd Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 18 Oct 2024 20:50:59 -0700 Subject: [PATCH 05/10] Make all methods `async def` again; add completion() for meta-reference (#270) PR #201 had made several changes while trying to fix issues with getting the stream=False branches of inference and agents API working. As part of this, it made a change which was slightly gratuitous. Namely, making chat_completion() and brethren "def" instead of "async def". The rationale was that this allowed the user (within llama-stack) of this to use it as: ``` async for chunk in api.chat_completion(params) ``` However, it causes unnecessary confusion for several folks. Given that clients (e.g., llama-stack-apps) anyway use the SDK methods (which are completely isolated) this choice was not ideal. Let's revert back so the call now looks like: ``` async for chunk in await api.chat_completion(params) ``` Bonus: Added a completion() implementation for the meta-reference provider. 
Technically should have been another PR :) --- docs/resources/llama-stack-spec.html | 78 +++++---- docs/resources/llama-stack-spec.yaml | 37 ++-- llama_stack/apis/agents/agents.py | 4 +- llama_stack/apis/agents/client.py | 6 +- llama_stack/apis/inference/client.py | 7 +- llama_stack/apis/inference/inference.py | 13 +- llama_stack/distribution/routers/routers.py | 12 +- .../adapters/inference/bedrock/bedrock.py | 4 +- .../inference/databricks/databricks.py | 6 +- .../adapters/inference/fireworks/fireworks.py | 6 +- .../adapters/inference/ollama/ollama.py | 6 +- .../providers/adapters/inference/tgi/tgi.py | 6 +- .../adapters/inference/together/together.py | 4 +- .../meta_reference/agents/agent_instance.py | 2 +- .../impls/meta_reference/agents/agents.py | 2 +- .../meta_reference/inference/generation.py | 51 +++--- .../meta_reference/inference/inference.py | 160 ++++++++++++++---- .../inference/model_parallel.py | 50 +++--- .../inference/parallel_utils.py | 30 ++-- .../meta_reference/safety/llama_guard.py | 2 +- llama_stack/providers/impls/vllm/vllm.py | 6 +- .../providers/tests/agents/test_agents.py | 8 +- .../tests/inference/test_inference.py | 43 ++++- 23 files changed, 330 insertions(+), 213 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index a2f92b6e4..8e6683931 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-10 15:29:56.831109" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-18 20:48:17.730988" }, "servers": [ { @@ -2830,8 +2830,11 @@ "CompletionResponse": { "type": "object", "properties": { - "completion_message": { - "$ref": "#/components/schemas/CompletionMessage" + "content": { + "type": "string" + }, + "stop_reason": { + "$ref": "#/components/schemas/StopReason" }, "logprobs": { "type": "array", @@ -2842,7 +2845,8 @@ }, "additionalProperties": false, "required": [ - "completion_message" + "content", + "stop_reason" ], "title": "Completion response." 
}, @@ -6075,49 +6079,49 @@ ], "tags": [ { - "name": "Evaluations" - }, - { - "name": "Inspect" + "name": "Models" }, { "name": "RewardScoring" }, { - "name": "Datasets" - }, - { - "name": "Models" - }, - { - "name": "Telemetry" - }, - { - "name": "PostTraining" - }, - { - "name": "SyntheticDataGeneration" - }, - { - "name": "BatchInference" - }, - { - "name": "Inference" - }, - { - "name": "Agents" - }, - { - "name": "Memory" - }, - { - "name": "Safety" + "name": "MemoryBanks" }, { "name": "Shields" }, { - "name": "MemoryBanks" + "name": "SyntheticDataGeneration" + }, + { + "name": "Inference" + }, + { + "name": "Inspect" + }, + { + "name": "BatchInference" + }, + { + "name": "Memory" + }, + { + "name": "Datasets" + }, + { + "name": "Agents" + }, + { + "name": "PostTraining" + }, + { + "name": "Telemetry" + }, + { + "name": "Safety" + }, + { + "name": "Evaluations" }, { "name": "BuiltinTool", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index c9822d6ca..906d3934a 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -501,14 +501,17 @@ components: CompletionResponse: additionalProperties: false properties: - completion_message: - $ref: '#/components/schemas/CompletionMessage' + content: + type: string logprobs: items: $ref: '#/components/schemas/TokenLogProbs' type: array + stop_reason: + $ref: '#/components/schemas/StopReason' required: - - completion_message + - content + - stop_reason title: Completion response. type: object CompletionResponseStreamChunk: @@ -2507,7 +2510,7 @@ info: description: "This is the specification of the llama stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. The specification is still in\ - \ draft and subject to change.\n Generated at 2024-10-10 15:29:56.831109" + \ draft and subject to change.\n Generated at 2024-10-18 20:48:17.730988" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -3712,21 +3715,21 @@ security: servers: - url: http://any-hosted-llama-stack.com tags: -- name: Evaluations -- name: Inspect -- name: RewardScoring -- name: Datasets - name: Models -- name: Telemetry -- name: PostTraining -- name: SyntheticDataGeneration -- name: BatchInference -- name: Inference -- name: Agents -- name: Memory -- name: Safety -- name: Shields +- name: RewardScoring - name: MemoryBanks +- name: Shields +- name: SyntheticDataGeneration +- name: Inference +- name: Inspect +- name: BatchInference +- name: Memory +- name: Datasets +- name: Agents +- name: PostTraining +- name: Telemetry +- name: Safety +- name: Evaluations - description: name: BuiltinTool - description: AgentCreateResponse: ... - # This method is not `async def` because it can result in either an - # `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`. 
@webmethod(route="/agents/turn/create") - def create_agent_turn( + async def create_agent_turn( self, agent_id: str, session_id: str, diff --git a/llama_stack/apis/agents/client.py b/llama_stack/apis/agents/client.py index 32bc9abdd..b45447328 100644 --- a/llama_stack/apis/agents/client.py +++ b/llama_stack/apis/agents/client.py @@ -67,14 +67,14 @@ class AgentsClient(Agents): response.raise_for_status() return AgentSessionCreateResponse(**response.json()) - def create_agent_turn( + async def create_agent_turn( self, request: AgentTurnCreateRequest, ) -> AsyncGenerator: if request.stream: return self._stream_agent_turn(request) else: - return self._nonstream_agent_turn(request) + return await self._nonstream_agent_turn(request) async def _stream_agent_turn( self, request: AgentTurnCreateRequest @@ -126,7 +126,7 @@ async def _run_agent( for content in user_prompts: cprint(f"User> {content}", color="white", attrs=["bold"]) - iterator = api.create_agent_turn( + iterator = await api.create_agent_turn( AgentTurnCreateRequest( agent_id=create_response.agent_id, session_id=session_response.session_id, diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py index 79d2cc02c..90636fa36 100644 --- a/llama_stack/apis/inference/client.py +++ b/llama_stack/apis/inference/client.py @@ -42,10 +42,10 @@ class InferenceClient(Inference): async def shutdown(self) -> None: pass - def completion(self, request: CompletionRequest) -> AsyncGenerator: + async def completion(self, request: CompletionRequest) -> AsyncGenerator: raise NotImplementedError() - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], @@ -139,7 +139,8 @@ async def run_main( else: logprobs_config = None - iterator = client.chat_completion( + assert stream, "Non streaming not supported here" + iterator = await client.chat_completion( model=model, messages=[message], stream=stream, diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 588dd37ca..5895e528e 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -88,7 +88,8 @@ class CompletionRequest(BaseModel): class CompletionResponse(BaseModel): """Completion response.""" - completion_message: CompletionMessage + content: str + stop_reason: StopReason logprobs: Optional[List[TokenLogProbs]] = None @@ -113,7 +114,7 @@ class BatchCompletionRequest(BaseModel): class BatchCompletionResponse(BaseModel): """Batch completion response.""" - completion_message_batch: List[CompletionMessage] + batch: List[CompletionResponse] @json_schema_type @@ -165,7 +166,7 @@ class BatchChatCompletionRequest(BaseModel): @json_schema_type class BatchChatCompletionResponse(BaseModel): - completion_message_batch: List[CompletionMessage] + batch: List[ChatCompletionResponse] @json_schema_type @@ -181,10 +182,8 @@ class ModelStore(Protocol): class Inference(Protocol): model_store: ModelStore - # This method is not `async def` because it can result in either an - # `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`. @webmethod(route="/inference/completion") - def completion( + async def completion( self, model: str, content: InterleavedTextMedia, @@ -196,7 +195,7 @@ class Inference(Protocol): # This method is not `async def` because it can result in either an # `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`. 
@webmethod(route="/inference/chat_completion") - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index cf62da1d0..a78e808d0 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -70,7 +70,7 @@ class InferenceRouter(Inference): async def register_model(self, model: ModelDef) -> None: await self.routing_table.register_model(model) - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], @@ -93,11 +93,11 @@ class InferenceRouter(Inference): ) provider = self.routing_table.get_provider_impl(model) if stream: - return (chunk async for chunk in provider.chat_completion(**params)) + return (chunk async for chunk in await provider.chat_completion(**params)) else: - return provider.chat_completion(**params) + return await provider.chat_completion(**params) - def completion( + async def completion( self, model: str, content: InterleavedTextMedia, @@ -114,9 +114,9 @@ class InferenceRouter(Inference): logprobs=logprobs, ) if stream: - return (chunk async for chunk in provider.completion(**params)) + return (chunk async for chunk in await provider.completion(**params)) else: - return provider.completion(**params) + return await provider.completion(**params) async def embeddings( self, diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py index 22f87ef6b..8440ecc20 100644 --- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py @@ -47,7 +47,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): async def shutdown(self) -> None: self.client.close() - def completion( + async def completion( self, model: str, content: InterleavedTextMedia, @@ -283,7 +283,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): ) return tool_config - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], diff --git a/llama_stack/providers/adapters/inference/databricks/databricks.py b/llama_stack/providers/adapters/inference/databricks/databricks.py index 141051186..9f50ad227 100644 --- a/llama_stack/providers/adapters/inference/databricks/databricks.py +++ b/llama_stack/providers/adapters/inference/databricks/databricks.py @@ -48,7 +48,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): async def shutdown(self) -> None: pass - def completion( + async def completion( self, model: str, content: InterleavedTextMedia, @@ -58,7 +58,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): ) -> AsyncGenerator: raise NotImplementedError() - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], @@ -84,7 +84,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): if stream: return self._stream_chat_completion(request, client) else: - return self._nonstream_chat_completion(request, client) + return await self._nonstream_chat_completion(request, client) async def _nonstream_chat_completion( self, request: ChatCompletionRequest, client: OpenAI diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py index c82012cba..537f3a6b4 100644 --- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py +++ 
b/llama_stack/providers/adapters/inference/fireworks/fireworks.py @@ -51,7 +51,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): async def shutdown(self) -> None: pass - def completion( + async def completion( self, model: str, content: InterleavedTextMedia, @@ -61,7 +61,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): ) -> AsyncGenerator: raise NotImplementedError() - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], @@ -87,7 +87,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): if stream: return self._stream_chat_completion(request, client) else: - return self._nonstream_chat_completion(request, client) + return await self._nonstream_chat_completion(request, client) async def _nonstream_chat_completion( self, request: ChatCompletionRequest, client: Fireworks diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py index c50c869fd..3a3e4b451 100644 --- a/llama_stack/providers/adapters/inference/ollama/ollama.py +++ b/llama_stack/providers/adapters/inference/ollama/ollama.py @@ -84,7 +84,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): return ret - def completion( + async def completion( self, model: str, content: InterleavedTextMedia, @@ -94,7 +94,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): ) -> AsyncGenerator: raise NotImplementedError() - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], @@ -118,7 +118,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): if stream: return self._stream_chat_completion(request) else: - return self._nonstream_chat_completion(request) + return await self._nonstream_chat_completion(request) def _get_params(self, request: ChatCompletionRequest) -> dict: return { diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py index cd0afad0c..3c610099c 100644 --- a/llama_stack/providers/adapters/inference/tgi/tgi.py +++ b/llama_stack/providers/adapters/inference/tgi/tgi.py @@ -66,7 +66,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): async def shutdown(self) -> None: pass - def completion( + async def completion( self, model: str, content: InterleavedTextMedia, @@ -76,7 +76,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): ) -> AsyncGenerator: raise NotImplementedError() - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], @@ -101,7 +101,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): if stream: return self._stream_chat_completion(request) else: - return self._nonstream_chat_completion(request) + return await self._nonstream_chat_completion(request) async def _nonstream_chat_completion( self, request: ChatCompletionRequest diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py index 750ca126e..8c73d75ec 100644 --- a/llama_stack/providers/adapters/inference/together/together.py +++ b/llama_stack/providers/adapters/inference/together/together.py @@ -64,7 +64,7 @@ class TogetherInferenceAdapter( ) -> AsyncGenerator: raise NotImplementedError() - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], @@ -101,7 +101,7 @@ class TogetherInferenceAdapter( if stream: return self._stream_chat_completion(request, client) else: - return 
self._nonstream_chat_completion(request, client) + return await self._nonstream_chat_completion(request, client) async def _nonstream_chat_completion( self, request: ChatCompletionRequest, client: Together diff --git a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py index 0d334fdad..cbc7490fd 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py +++ b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py @@ -424,7 +424,7 @@ class ChatAgent(ShieldRunnerMixin): stop_reason = None with tracing.span("inference"): - async for chunk in self.inference_api.chat_completion( + async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, tools=self._get_tools(), diff --git a/llama_stack/providers/impls/meta_reference/agents/agents.py b/llama_stack/providers/impls/meta_reference/agents/agents.py index 5a209d0b7..8b3ece978 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agents.py +++ b/llama_stack/providers/impls/meta_reference/agents/agents.py @@ -105,7 +105,7 @@ class MetaReferenceAgentsImpl(Agents): session_id=session_id, ) - def create_agent_turn( + async def create_agent_turn( self, agent_id: str, session_id: str, diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py index 20a8addc7..9ca128176 100644 --- a/llama_stack/providers/impls/meta_reference/inference/generation.py +++ b/llama_stack/providers/impls/meta_reference/inference/generation.py @@ -23,11 +23,6 @@ from fairscale.nn.model_parallel.initialize import ( ) from llama_models.llama3.api.args import ModelArgs from llama_models.llama3.api.chat_format import ChatFormat, ModelInput -from llama_models.llama3.api.datatypes import ( - InterleavedTextMedia, - Message, - ToolPromptFormat, -) from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.reference_impl.model import Transformer from llama_models.llama3.reference_impl.multimodal.model import ( @@ -38,7 +33,11 @@ from llama_models.sku_list import resolve_model from pydantic import BaseModel from termcolor import cprint +from llama_stack.apis.inference import * # noqa: F403 from llama_stack.distribution.utils.model_utils import model_local_dir +from llama_stack.providers.utils.inference.prompt_adapter import ( + chat_completion_request_to_messages, +) from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig @@ -297,15 +296,12 @@ class Llama: if all(eos_reached): break - def text_completion( + def completion( self, - content: InterleavedTextMedia, - temperature: float = 0.6, - top_p: float = 0.9, - max_gen_len: Optional[int] = None, - logprobs: bool = False, - echo: bool = False, + request: CompletionRequest, ) -> Generator: + sampling_params = request.sampling_params + max_gen_len = sampling_params.max_tokens if ( max_gen_len is None or max_gen_len == 0 @@ -313,26 +309,25 @@ class Llama: ): max_gen_len = self.model.params.max_seq_len - 1 - model_input = self.formatter.encode_content(content) - + model_input = self.formatter.encode_content(request.content) yield from self.generate( model_input=model_input, max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - logprobs=logprobs, - echo=echo, + temperature=sampling_params.temperature, + top_p=sampling_params.top_p, + logprobs=bool(request.logprobs), + include_stop_token=True, + echo=False, ) def 
chat_completion( self, - messages: List[Message], - temperature: float = 0.6, - top_p: float = 0.9, - max_gen_len: Optional[int] = None, - logprobs: bool = False, - tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, + request: ChatCompletionRequest, ) -> Generator: + messages = chat_completion_request_to_messages(request) + + sampling_params = request.sampling_params + max_gen_len = sampling_params.max_tokens if ( max_gen_len is None or max_gen_len == 0 @@ -343,12 +338,12 @@ class Llama: yield from self.generate( model_input=self.formatter.encode_dialog_prompt( messages, - tool_prompt_format, + request.tool_prompt_format, ), max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - logprobs=logprobs, + temperature=sampling_params.temperature, + top_p=sampling_params.top_p, + logprobs=bool(request.logprobs), include_stop_token=True, ) diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py index 7edc279d0..34053343e 100644 --- a/llama_stack/providers/impls/meta_reference/inference/inference.py +++ b/llama_stack/providers/impls/meta_reference/inference/inference.py @@ -13,9 +13,6 @@ from llama_models.sku_list import resolve_model from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate -from llama_stack.providers.utils.inference.prompt_adapter import ( - chat_completion_request_to_messages, -) from .config import MetaReferenceInferenceConfig from .generation import Llama @@ -58,7 +55,18 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): if self.config.create_distributed_process_group: self.generator.stop() - def completion( + def check_model(self, request) -> None: + model = resolve_model(request.model) + if model is None: + raise RuntimeError( + f"Unknown model: {request.model}, Run `llama model list`" + ) + elif model.descriptor() != self.model.descriptor(): + raise RuntimeError( + f"Model mismatch: {request.model} != {self.model.descriptor()}" + ) + + async def completion( self, model: str, content: InterleavedTextMedia, @@ -66,9 +74,114 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: - raise NotImplementedError() + if logprobs: + assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}" - def chat_completion( + request = CompletionRequest( + model=model, + content=content, + sampling_params=sampling_params, + stream=stream, + logprobs=logprobs, + ) + self.check_model(request) + + if request.stream: + return self._stream_completion(request) + else: + return await self._nonstream_completion(request) + + async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: + def impl(): + stop_reason = None + + for token_result in self.generator.completion(request): + if token_result.text == "<|eot_id|>": + stop_reason = StopReason.end_of_turn + text = "" + elif token_result.text == "<|eom_id|>": + stop_reason = StopReason.end_of_message + text = "" + else: + text = token_result.text + + logprobs = None + if stop_reason is None: + if request.logprobs: + assert len(token_result.logprobs) == 1 + + logprobs = [ + TokenLogProbs( + logprobs_by_token={ + token_result.text: token_result.logprobs[0] + } + ) + ] + + yield CompletionResponseStreamChunk( + 
delta=text, + stop_reason=stop_reason, + logprobs=logprobs if request.logprobs else None, + ) + + if stop_reason is None: + yield CompletionResponseStreamChunk( + delta="", + stop_reason=StopReason.out_of_tokens, + ) + + if self.config.create_distributed_process_group: + async with SEMAPHORE: + for x in impl(): + yield x + else: + for x in impl(): + yield x + + async def _nonstream_completion( + self, request: CompletionRequest + ) -> CompletionResponse: + def impl(): + tokens = [] + logprobs = [] + stop_reason = None + + tokenizer = self.generator.formatter.tokenizer + for token_result in self.generator.completion(request): + tokens.append(token_result.token) + + if token_result.token in tokenizer.stop_tokens: + # not quite right semantically + stop_reason = StopReason.end_of_turn + + if request.logprobs: + assert len(token_result.logprobs) == 1 + + logprobs.append( + TokenLogProbs( + logprobs_by_token={ + token_result.text: token_result.logprobs[0] + } + ) + ) + + if stop_reason is None: + stop_reason = StopReason.out_of_tokens + + content = self.generator.formatter.tokenizer.decode(tokens) + return CompletionResponse( + content=content, + stop_reason=stop_reason, + logprobs=logprobs if request.logprobs else None, + ) + + if self.config.create_distributed_process_group: + async with SEMAPHORE: + return impl() + else: + return impl() + + async def chat_completion( self, model: str, messages: List[Message], @@ -93,16 +206,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): stream=stream, logprobs=logprobs, ) - - model = resolve_model(request.model) - if model is None: - raise RuntimeError( - f"Unknown model: {request.model}, Run `llama model list`" - ) - elif model.descriptor() != self.model.descriptor(): - raise RuntimeError( - f"Model mismatch: {request.model} != {self.model.descriptor()}" - ) + self.check_model(request) if self.config.create_distributed_process_group: if SEMAPHORE.locked(): @@ -111,26 +215,17 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): if request.stream: return self._stream_chat_completion(request) else: - return self._nonstream_chat_completion(request) + return await self._nonstream_chat_completion(request) async def _nonstream_chat_completion( self, request: ChatCompletionRequest ) -> ChatCompletionResponse: def impl(): - messages = chat_completion_request_to_messages(request) - tokens = [] logprobs = [] stop_reason = None - for token_result in self.generator.chat_completion( - messages=messages, - temperature=request.sampling_params.temperature, - top_p=request.sampling_params.top_p, - max_gen_len=request.sampling_params.max_tokens, - logprobs=request.logprobs, - tool_prompt_format=request.tool_prompt_format, - ): + for token_result in self.generator.chat_completion(request): tokens.append(token_result.token) if token_result.text == "<|eot_id|>": @@ -170,8 +265,6 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): self, request: ChatCompletionRequest ) -> AsyncGenerator: def impl(): - messages = chat_completion_request_to_messages(request) - yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.start, @@ -184,14 +277,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): stop_reason = None ipython = False - for token_result in self.generator.chat_completion( - messages=messages, - temperature=request.sampling_params.temperature, - top_p=request.sampling_params.top_p, - max_gen_len=request.sampling_params.max_tokens, - 
logprobs=request.logprobs, - tool_prompt_format=request.tool_prompt_format, - ): + for token_result in self.generator.chat_completion(request): tokens.append(token_result.token) if not ipython and token_result.text.startswith("<|python_tag|>"): diff --git a/llama_stack/providers/impls/meta_reference/inference/model_parallel.py b/llama_stack/providers/impls/meta_reference/inference/model_parallel.py index e8f483f30..7e7831185 100644 --- a/llama_stack/providers/impls/meta_reference/inference/model_parallel.py +++ b/llama_stack/providers/impls/meta_reference/inference/model_parallel.py @@ -7,16 +7,17 @@ import os from copy import deepcopy from functools import partial -from typing import Generator, List, Optional +from typing import Any, Generator from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import Message, ToolPromptFormat from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import resolve_model +from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest + from .config import MetaReferenceInferenceConfig from .generation import Llama, model_checkpoint_dir -from .parallel_utils import InferenceArgs, ModelParallelProcessGroup +from .parallel_utils import ModelParallelProcessGroup class ModelRunner: @@ -24,15 +25,13 @@ class ModelRunner: self.llama = llama # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()` - def __call__(self, task: InferenceArgs): - return self.llama.chat_completion( - task.messages, - task.temperature, - task.top_p, - task.max_gen_len, - task.logprobs, - task.tool_prompt_format, - ) + def __call__(self, req: Any): + if isinstance(req, ChatCompletionRequest): + return self.llama.chat_completion(req) + elif isinstance(req, CompletionRequest): + return self.llama.completion(req) + else: + raise ValueError(f"Unexpected task type {type(req)}") def init_model_cb(config: MetaReferenceInferenceConfig): @@ -77,23 +76,18 @@ class LlamaModelParallelGenerator: def __exit__(self, exc_type, exc_value, exc_traceback): self.group.stop() - def chat_completion( + def completion( self, - messages: List[Message], - temperature: float = 0.6, - top_p: float = 0.9, - max_gen_len: Optional[int] = None, - logprobs: bool = False, - tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, + request: CompletionRequest, ) -> Generator: - req_obj = InferenceArgs( - messages=deepcopy(messages), - temperature=temperature, - top_p=top_p, - max_gen_len=max_gen_len, - logprobs=logprobs or False, - tool_prompt_format=tool_prompt_format, - ) - + req_obj = deepcopy(request) + gen = self.group.run_inference(req_obj) + yield from gen + + def chat_completion( + self, + request: ChatCompletionRequest, + ) -> Generator: + req_obj = deepcopy(request) gen = self.group.run_inference(req_obj) yield from gen diff --git a/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py b/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py index 7dbedd0f0..62eeefaac 100644 --- a/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py +++ b/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py @@ -4,6 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +# Copyright (c) Meta Platforms, IAny, nc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + import json import multiprocessing import os @@ -11,10 +17,9 @@ import tempfile import time import uuid from enum import Enum -from typing import Callable, Generator, List, Literal, Optional, Union +from typing import Callable, Generator, Literal, Optional, Union import torch - import zmq from fairscale.nn.model_parallel.initialize import ( @@ -23,25 +28,16 @@ from fairscale.nn.model_parallel.initialize import ( get_model_parallel_src_rank, ) -from llama_models.llama3.api.datatypes import Message, ToolPromptFormat - from pydantic import BaseModel, Field from torch.distributed.launcher.api import elastic_launch, LaunchConfig from typing_extensions import Annotated +from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest + from .generation import TokenResult -class InferenceArgs(BaseModel): - messages: List[Message] - temperature: float - top_p: float - max_gen_len: int - logprobs: bool - tool_prompt_format: ToolPromptFormat - - class ProcessingMessageName(str, Enum): ready_request = "ready_request" ready_response = "ready_response" @@ -80,7 +76,7 @@ class TaskRequest(BaseModel): type: Literal[ProcessingMessageName.task_request] = ( ProcessingMessageName.task_request ) - task: InferenceArgs + task: Union[CompletionRequest, ChatCompletionRequest] class TaskResponse(BaseModel): @@ -349,11 +345,13 @@ class ModelParallelProcessGroup: self.process.join() self.started = False - def run_inference(self, inference_args: InferenceArgs) -> Generator: + def run_inference( + self, req: Union[CompletionRequest, ChatCompletionRequest] + ) -> Generator: assert not self.running, "inference already running" self.running = True - self.request_socket.send(encode_msg(TaskRequest(task=inference_args))) + self.request_socket.send(encode_msg(TaskRequest(task=req))) try: while True: obj_json = self.request_socket.recv() diff --git a/llama_stack/providers/impls/meta_reference/safety/llama_guard.py b/llama_stack/providers/impls/meta_reference/safety/llama_guard.py index a6f450fae..99b1c29be 100644 --- a/llama_stack/providers/impls/meta_reference/safety/llama_guard.py +++ b/llama_stack/providers/impls/meta_reference/safety/llama_guard.py @@ -184,7 +184,7 @@ class LlamaGuardShield(ShieldBase): # TODO: llama-stack inference protocol has issues with non-streaming inference code content = "" - async for chunk in self.inference_api.chat_completion( + async for chunk in await self.inference_api.chat_completion( model=self.model, messages=[shield_input_message], stream=True, diff --git a/llama_stack/providers/impls/vllm/vllm.py b/llama_stack/providers/impls/vllm/vllm.py index 5cdb1a2ab..c977c738d 100644 --- a/llama_stack/providers/impls/vllm/vllm.py +++ b/llama_stack/providers/impls/vllm/vllm.py @@ -134,7 +134,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference): if self.engine: self.engine.shutdown_background_loop() - def completion( + async def completion( self, model: str, content: InterleavedTextMedia, @@ -152,7 +152,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference): logprobs=logprobs, ) - def chat_completion( + async def chat_completion( self, model: str, messages: list[Message], @@ -189,7 +189,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference): if stream: return self._stream_chat_completion(request, results_generator) else: - return self._nonstream_chat_completion(request, results_generator) + return await 
self._nonstream_chat_completion(request, results_generator) async def _nonstream_chat_completion( self, request: ChatCompletionRequest, results_generator: AsyncGenerator diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 6774d3f1f..9c34c3a28 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -116,7 +116,7 @@ async def test_create_agent_turn(agents_settings, sample_messages): ) turn_response = [ - chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] assert len(turn_response) > 0 @@ -204,7 +204,7 @@ async def test_rag_agent_as_attachments( ) turn_response = [ - chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] assert len(turn_response) > 0 @@ -218,7 +218,7 @@ async def test_rag_agent_as_attachments( ) turn_response = [ - chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] assert len(turn_response) > 0 @@ -270,7 +270,7 @@ async def test_create_agent_turn_with_brave_search( ) turn_response = [ - chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] assert len(turn_response) > 0 diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py index 581a0d428..09d6a69db 100644 --- a/llama_stack/providers/tests/inference/test_inference.py +++ b/llama_stack/providers/tests/inference/test_inference.py @@ -126,6 +126,45 @@ async def test_model_list(inference_settings): assert model_def.identifier == params["model"] +@pytest.mark.asyncio +async def test_completion(inference_settings): + inference_impl = inference_settings["impl"] + params = inference_settings["common_params"] + + provider = inference_impl.routing_table.get_provider_impl(params["model"]) + if provider.__provider_id__ != "meta-reference": + pytest.skip("Other inference providers don't support completion() yet") + + response = await inference_impl.completion( + content="Roses are red,", + stream=False, + model=params["model"], + sampling_params=SamplingParams( + max_tokens=50, + ), + ) + + assert isinstance(response, CompletionResponse) + assert "violets are blue" in response.content + + chunks = [ + r + async for r in await inference_impl.completion( + content="Roses are red,", + stream=True, + model=params["model"], + sampling_params=SamplingParams( + max_tokens=50, + ), + ) + ] + + assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks) + assert len(chunks) == 51 + last = chunks[-1] + assert last.stop_reason == StopReason.out_of_tokens + + @pytest.mark.asyncio async def test_chat_completion_non_streaming(inference_settings, sample_messages): inference_impl = inference_settings["impl"] @@ -146,7 +185,7 @@ async def test_chat_completion_streaming(inference_settings, sample_messages): inference_impl = inference_settings["impl"] response = [ r - async for r in inference_impl.chat_completion( + async for r in await inference_impl.chat_completion( messages=sample_messages, stream=True, **inference_settings["common_params"], @@ -217,7 +256,7 @@ async def test_chat_completion_with_tool_calling_streaming( response = [ r - async 
for r in inference_impl.chat_completion( + async for r in await inference_impl.chat_completion( messages=messages, tools=[sample_tool_definition], stream=True, From 8cfbb9d38b80ca5930e8ba20756cdaa51af30ca0 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 19 Oct 2024 17:19:54 -0700 Subject: [PATCH 06/10] Improve an important error message --- .../distribution/routers/routing_tables.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index ede30aea1..597dbed07 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -87,8 +87,21 @@ class CommonRoutingTableImpl(RoutingTable): def get_provider_impl( self, routing_key: str, provider_id: Optional[str] = None ) -> Any: + def apiname_object(): + if isinstance(self, ModelsRoutingTable): + return ("Inference", "model") + elif isinstance(self, ShieldsRoutingTable): + return ("Safety", "shield") + elif isinstance(self, MemoryBanksRoutingTable): + return ("Memory", "memory_bank") + else: + raise ValueError("Unknown routing table type") + if routing_key not in self.registry: - raise ValueError(f"`{routing_key}` not registered") + apiname, objname = apiname_object() + raise ValueError( + f"`{routing_key}` not registered. Make sure there is an {apiname} provider serving this {objname}." + ) objs = self.registry[routing_key] for obj in objs: From 59c43736e83c2b3b9726441a28bab09e4d92d52f Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 19 Oct 2024 17:26:18 -0700 Subject: [PATCH 07/10] update ollama for llama-guard3 --- llama_stack/providers/adapters/inference/ollama/ollama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py index 3a3e4b451..74aed6e5e 100644 --- a/llama_stack/providers/adapters/inference/ollama/ollama.py +++ b/llama_stack/providers/adapters/inference/ollama/ollama.py @@ -33,7 +33,8 @@ OLLAMA_SUPPORTED_MODELS = { "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16", "Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16", "Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16", - "Llama-Guard-3-8B": "xe/llamaguard3:latest", + "Llama-Guard-3-8B": "llama-guard3:8b", + "Llama-Guard-3-1B": "llama-guard3:1b", } From a27a2cd2af93d73e48d9789ac92c55927da6c44d Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Sun, 20 Oct 2024 21:43:25 -0400 Subject: [PATCH 08/10] Add vLLM inference provider for OpenAI compatible vLLM server (#178) This PR adds vLLM inference provider for OpenAI compatible vLLM server. 
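The adapter is a thin wrapper over the standard `openai` client pointed at the
server's OpenAI-compatible endpoint. A minimal wiring sketch for using it as a
library follows; the endpoint URL, the placeholder token, the empty `_deps`
argument, and the import paths for `UserMessage`/`SamplingParams` are
illustrative assumptions, not part of this PR:

```
import asyncio

from llama_stack.apis.inference import SamplingParams, UserMessage
from llama_stack.providers.adapters.inference.vllm import get_adapter_impl
from llama_stack.providers.adapters.inference.vllm.config import VLLMImplConfig


async def main():
    # Assumed: a local vLLM server running its OpenAI-compatible frontend;
    # the token is only checked if the server was launched with one.
    config = VLLMImplConfig(url="http://localhost:8000/v1", api_token="dummy")
    adapter = await get_adapter_impl(config, {})  # second arg (_deps) is unused here

    # Streaming follows the `async for chunk in await api.chat_completion(...)`
    # convention restored earlier in this series.
    async for chunk in await adapter.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=[UserMessage(content="Hello!")],
        sampling_params=SamplingParams(max_tokens=64),
        stream=True,
    ):
        print(chunk)


asyncio.run(main())
```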
---
 .../build_configs/local-vllm-build.yaml       |   2 +-
 .../templates/remote-vllm-build.yaml          |  10 ++
 .../adapters/inference/vllm/__init__.py       |  15 ++
 .../adapters/inference/vllm/config.py         |  22 +++
 .../providers/adapters/inference/vllm/vllm.py | 152 ++++++++++++++++++
 llama_stack/providers/registry/inference.py   |   9 ++
 6 files changed, 209 insertions(+), 1 deletion(-)
 create mode 100644 llama_stack/distribution/templates/remote-vllm-build.yaml
 create mode 100644 llama_stack/providers/adapters/inference/vllm/__init__.py
 create mode 100644 llama_stack/providers/adapters/inference/vllm/config.py
 create mode 100644 llama_stack/providers/adapters/inference/vllm/vllm.py

diff --git a/llama_stack/distribution/templates/build_configs/local-vllm-build.yaml b/llama_stack/distribution/templates/build_configs/local-vllm-build.yaml
index e907cb7c9..e333a137b 100644
--- a/llama_stack/distribution/templates/build_configs/local-vllm-build.yaml
+++ b/llama_stack/distribution/templates/build_configs/local-vllm-build.yaml
@@ -7,4 +7,4 @@ distribution_spec:
     safety: meta-reference
     agents: meta-reference
     telemetry: meta-reference
-image_type: conda
+image_type: conda
\ No newline at end of file

diff --git a/llama_stack/distribution/templates/remote-vllm-build.yaml b/llama_stack/distribution/templates/remote-vllm-build.yaml
new file mode 100644
index 000000000..525c3a930
--- /dev/null
+++ b/llama_stack/distribution/templates/remote-vllm-build.yaml
@@ -0,0 +1,10 @@
+name: remote-vllm
+distribution_spec:
+  description: Use remote vLLM for running LLM inference
+  providers:
+    inference: remote::vllm
+    memory: meta-reference
+    safety: meta-reference
+    agents: meta-reference
+    telemetry: meta-reference
+image_type: docker
\ No newline at end of file

diff --git a/llama_stack/providers/adapters/inference/vllm/__init__.py b/llama_stack/providers/adapters/inference/vllm/__init__.py
new file mode 100644
index 000000000..f4588a307
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/vllm/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import VLLMImplConfig
+from .vllm import VLLMInferenceAdapter
+
+
+async def get_adapter_impl(config: VLLMImplConfig, _deps):
+    assert isinstance(config, VLLMImplConfig), f"Unexpected config type: {type(config)}"
+    impl = VLLMInferenceAdapter(config)
+    await impl.initialize()
+    return impl

diff --git a/llama_stack/providers/adapters/inference/vllm/config.py b/llama_stack/providers/adapters/inference/vllm/config.py
new file mode 100644
index 000000000..65815922c
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/vllm/config.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Optional
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field
+
+
+@json_schema_type
+class VLLMImplConfig(BaseModel):
+    url: Optional[str] = Field(
+        default=None,
+        description="The URL for the vLLM model serving endpoint",
+    )
+    api_token: Optional[str] = Field(
+        default=None,
+        description="The API token",
+    )

diff --git a/llama_stack/providers/adapters/inference/vllm/vllm.py b/llama_stack/providers/adapters/inference/vllm/vllm.py
new file mode 100644
index 000000000..a5934928a
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/vllm/vllm.py
@@ -0,0 +1,152 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import AsyncGenerator
+
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.tokenizer import Tokenizer
+
+from openai import OpenAI
+
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.datatypes import ModelsProtocolPrivate
+
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_options,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+)
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)
+
+from .config import VLLMImplConfig
+
+VLLM_SUPPORTED_MODELS = {
+    "Llama3.1-8B": "meta-llama/Llama-3.1-8B",
+    "Llama3.1-70B": "meta-llama/Llama-3.1-70B",
+    "Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
+    "Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
+    "Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
+    "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
+    "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
+    "Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
+    "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
+    "Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
+    "Llama3.2-1B": "meta-llama/Llama-3.2-1B",
+    "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
+    "Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision",
+    "Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision",
+    "Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
+    "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
+    "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
+    "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
+    "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
+    "Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4",
+    "Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
+    "Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
+    "Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8",
+    "Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M",
+    "Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B",
+}
+
+
+class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
+    def __init__(self, config: VLLMImplConfig) -> None:
+        self.config = config
+        self.formatter = ChatFormat(Tokenizer.get_instance())
+        self.client = None
+
+    async def initialize(self) -> None:
+        self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
+
+    async def register_model(self, model: ModelDef) -> None:
+        raise ValueError("Model registration is not supported for vLLM models")
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def list_models(self) -> List[ModelDef]:
+        return [
+            ModelDef(identifier=model.id, llama_model=model.id)
+            for model in self.client.models.list()
+        ]
+
+    def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+        raise NotImplementedError()
+
+    def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        if stream:
+            return self._stream_chat_completion(request, self.client)
+        else:
+            return self._nonstream_chat_completion(request, self.client)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest, client: OpenAI
+    ) -> ChatCompletionResponse:
+        params = self._get_params(request)
+        r = client.completions.create(**params)
+        return process_chat_completion_response(request, r, self.formatter)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: OpenAI
+    ) -> AsyncGenerator:
+        params = self._get_params(request)
+
+        # TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
+        # generator so this wrapper is not necessary?
+
+        async def _to_async_generator():
+            s = client.completions.create(**params)
+            for chunk in s:
+                yield chunk
+
+        stream = _to_async_generator()
+        async for chunk in process_chat_completion_stream_response(
+            request, stream, self.formatter
+        ):
+            yield chunk
+
+    def _get_params(self, request: ChatCompletionRequest) -> dict:
+        return {
+            "model": VLLM_SUPPORTED_MODELS[request.model],
+            "prompt": chat_completion_request_to_prompt(request, self.formatter),
+            "stream": request.stream,
+            **get_sampling_options(request),
+        }
+
+    async def embeddings(
+        self,
+        model: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse:
+        raise NotImplementedError()

diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 686fc273b..c3370bfd9 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -60,6 +60,15 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.adapters.inference.ollama",
         ),
     ),
+    # remote_provider_spec(
+    #     api=Api.inference,
+    #     adapter=AdapterSpec(
+    #         adapter_type="vllm",
+    #         pip_packages=["openai"],
+    #         module="llama_stack.providers.adapters.inference.vllm",
+    #         config_class="llama_stack.providers.adapters.inference.vllm.VLLMImplConfig",
+    #     ),
+    # ),
     remote_provider_spec(
         api=Api.inference,
         adapter=AdapterSpec(

From cae5b0708b161658646a1971ab88ecdaa18ad488 Mon Sep 17 00:00:00 2001
From: raghotham
Date: Mon, 21 Oct 2024 11:48:19 +0530
Subject: [PATCH 09/10] Create .readthedocs.yaml

Trying out readthedocs
---
 .readthedocs.yaml | 32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)
 create mode 100644 .readthedocs.yaml

diff --git a/.readthedocs.yaml b/.readthedocs.yaml
new file mode 100644
index 000000000..f89fc906d
--- /dev/null
+++ b/.readthedocs.yaml
@@ -0,0 +1,32 @@
+# .readthedocs.yaml
+# Read the Docs configuration file
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+# Required
+version: 2
+
+# Set the OS, Python version and other tools you might need
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.12"
+    # You can also specify other tool versions:
+    # nodejs: "19"
+    # rust: "1.64"
+    # golang: "1.19"
+
+# Build documentation in the "docs/" directory with Sphinx
+sphinx:
+  configuration: docs/conf.py
+
+# Optionally build your docs in additional formats such as PDF and ePub
+# formats:
+#   - pdf
+#   - epub
+
+# Optional but recommended, declare the Python requirements required
+# to build your documentation
+# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
+# python:
+#   install:
+#     - requirements: docs/requirements.txt

From c995219731ec6961e03e096b5829d6b7a38980d7 Mon Sep 17 00:00:00 2001
From: nehal-a2z
Date: Mon, 21 Oct 2024 23:16:53 +0530
Subject: [PATCH 10/10] Update event_logger.py (#275)

spelling error
---
 llama_stack/apis/agents/event_logger.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama_stack/apis/agents/event_logger.py b/llama_stack/apis/agents/event_logger.py
index b5ad6ae91..25931b821 100644
--- a/llama_stack/apis/agents/event_logger.py
+++ b/llama_stack/apis/agents/event_logger.py
@@ -180,5 +180,5 @@ class EventLogger:
                     color="cyan",
                 )
 
-            preivous_event_type = event_type
+            previous_event_type = event_type
             previous_step_type = step_type