forked from phoenix-oss/llama-stack-mirror
Inference to use provider resource id to register and validate (#428)
This PR changes the way model id gets translated to the final model name that gets passed through the provider. Major changes include: 1) Providers are responsible for registering an object and as part of the registration returning the object with the correct provider specific name of the model provider_resource_id 2) To help with the common look ups different names a new ModelLookup class is created. Tested all inference providers including together, fireworks, vllm, ollama, meta reference and bedrock
This commit is contained in:
parent
e51107e019
commit
fdff24e77a
21 changed files with 460 additions and 290 deletions
|
@ -21,7 +21,7 @@
|
||||||
"info": {
|
"info": {
|
||||||
"title": "[DRAFT] Llama Stack Specification",
|
"title": "[DRAFT] Llama Stack Specification",
|
||||||
"version": "0.0.1",
|
"version": "0.0.1",
|
||||||
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782"
|
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 15:47:15.607543"
|
||||||
},
|
},
|
||||||
"servers": [
|
"servers": [
|
||||||
{
|
{
|
||||||
|
@ -2856,7 +2856,7 @@
|
||||||
"ChatCompletionRequest": {
|
"ChatCompletionRequest": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"model": {
|
"model_id": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"messages": {
|
"messages": {
|
||||||
|
@ -2993,7 +2993,7 @@
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": [
|
"required": [
|
||||||
"model",
|
"model_id",
|
||||||
"messages"
|
"messages"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -3120,7 +3120,7 @@
|
||||||
"CompletionRequest": {
|
"CompletionRequest": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"model": {
|
"model_id": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"content": {
|
"content": {
|
||||||
|
@ -3249,7 +3249,7 @@
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": [
|
"required": [
|
||||||
"model",
|
"model_id",
|
||||||
"content"
|
"content"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -4552,7 +4552,7 @@
|
||||||
"EmbeddingsRequest": {
|
"EmbeddingsRequest": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"model": {
|
"model_id": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"contents": {
|
"contents": {
|
||||||
|
@ -4584,7 +4584,7 @@
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": [
|
"required": [
|
||||||
"model",
|
"model_id",
|
||||||
"contents"
|
"contents"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -7837,34 +7837,10 @@
|
||||||
],
|
],
|
||||||
"tags": [
|
"tags": [
|
||||||
{
|
{
|
||||||
"name": "MemoryBanks"
|
"name": "Safety"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "BatchInference"
|
"name": "EvalTasks"
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Agents"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Inference"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "DatasetIO"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Eval"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Models"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "PostTraining"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "ScoringFunctions"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Datasets"
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Shields"
|
"name": "Shields"
|
||||||
|
@ -7872,15 +7848,6 @@
|
||||||
{
|
{
|
||||||
"name": "Telemetry"
|
"name": "Telemetry"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"name": "Inspect"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Safety"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "SyntheticDataGeneration"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "Memory"
|
"name": "Memory"
|
||||||
},
|
},
|
||||||
|
@ -7888,7 +7855,40 @@
|
||||||
"name": "Scoring"
|
"name": "Scoring"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "EvalTasks"
|
"name": "ScoringFunctions"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "SyntheticDataGeneration"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Models"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Agents"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "MemoryBanks"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "DatasetIO"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Inference"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Datasets"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "PostTraining"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "BatchInference"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Eval"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Inspect"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "BuiltinTool",
|
"name": "BuiltinTool",
|
||||||
|
|
|
@ -396,7 +396,7 @@ components:
|
||||||
- $ref: '#/components/schemas/ToolResponseMessage'
|
- $ref: '#/components/schemas/ToolResponseMessage'
|
||||||
- $ref: '#/components/schemas/CompletionMessage'
|
- $ref: '#/components/schemas/CompletionMessage'
|
||||||
type: array
|
type: array
|
||||||
model:
|
model_id:
|
||||||
type: string
|
type: string
|
||||||
response_format:
|
response_format:
|
||||||
oneOf:
|
oneOf:
|
||||||
|
@ -453,7 +453,7 @@ components:
|
||||||
$ref: '#/components/schemas/ToolDefinition'
|
$ref: '#/components/schemas/ToolDefinition'
|
||||||
type: array
|
type: array
|
||||||
required:
|
required:
|
||||||
- model
|
- model_id
|
||||||
- messages
|
- messages
|
||||||
type: object
|
type: object
|
||||||
ChatCompletionResponse:
|
ChatCompletionResponse:
|
||||||
|
@ -577,7 +577,7 @@ components:
|
||||||
default: 0
|
default: 0
|
||||||
type: integer
|
type: integer
|
||||||
type: object
|
type: object
|
||||||
model:
|
model_id:
|
||||||
type: string
|
type: string
|
||||||
response_format:
|
response_format:
|
||||||
oneOf:
|
oneOf:
|
||||||
|
@ -626,7 +626,7 @@ components:
|
||||||
stream:
|
stream:
|
||||||
type: boolean
|
type: boolean
|
||||||
required:
|
required:
|
||||||
- model
|
- model_id
|
||||||
- content
|
- content
|
||||||
type: object
|
type: object
|
||||||
CompletionResponse:
|
CompletionResponse:
|
||||||
|
@ -903,10 +903,10 @@ components:
|
||||||
- $ref: '#/components/schemas/ImageMedia'
|
- $ref: '#/components/schemas/ImageMedia'
|
||||||
type: array
|
type: array
|
||||||
type: array
|
type: array
|
||||||
model:
|
model_id:
|
||||||
type: string
|
type: string
|
||||||
required:
|
required:
|
||||||
- model
|
- model_id
|
||||||
- contents
|
- contents
|
||||||
type: object
|
type: object
|
||||||
EmbeddingsResponse:
|
EmbeddingsResponse:
|
||||||
|
@ -3384,7 +3384,7 @@ info:
|
||||||
description: "This is the specification of the llama stack that provides\n \
|
description: "This is the specification of the llama stack that provides\n \
|
||||||
\ a set of endpoints and their corresponding interfaces that are tailored\
|
\ a set of endpoints and their corresponding interfaces that are tailored\
|
||||||
\ to\n best leverage Llama Models. The specification is still in\
|
\ to\n best leverage Llama Models. The specification is still in\
|
||||||
\ draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782"
|
\ draft and subject to change.\n Generated at 2024-11-12 15:47:15.607543"
|
||||||
title: '[DRAFT] Llama Stack Specification'
|
title: '[DRAFT] Llama Stack Specification'
|
||||||
version: 0.0.1
|
version: 0.0.1
|
||||||
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
|
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
|
||||||
|
@ -4748,24 +4748,24 @@ security:
|
||||||
servers:
|
servers:
|
||||||
- url: http://any-hosted-llama-stack.com
|
- url: http://any-hosted-llama-stack.com
|
||||||
tags:
|
tags:
|
||||||
- name: MemoryBanks
|
- name: Safety
|
||||||
- name: BatchInference
|
- name: EvalTasks
|
||||||
- name: Agents
|
|
||||||
- name: Inference
|
|
||||||
- name: DatasetIO
|
|
||||||
- name: Eval
|
|
||||||
- name: Models
|
|
||||||
- name: PostTraining
|
|
||||||
- name: ScoringFunctions
|
|
||||||
- name: Datasets
|
|
||||||
- name: Shields
|
- name: Shields
|
||||||
- name: Telemetry
|
- name: Telemetry
|
||||||
- name: Inspect
|
|
||||||
- name: Safety
|
|
||||||
- name: SyntheticDataGeneration
|
|
||||||
- name: Memory
|
- name: Memory
|
||||||
- name: Scoring
|
- name: Scoring
|
||||||
- name: EvalTasks
|
- name: ScoringFunctions
|
||||||
|
- name: SyntheticDataGeneration
|
||||||
|
- name: Models
|
||||||
|
- name: Agents
|
||||||
|
- name: MemoryBanks
|
||||||
|
- name: DatasetIO
|
||||||
|
- name: Inference
|
||||||
|
- name: Datasets
|
||||||
|
- name: PostTraining
|
||||||
|
- name: BatchInference
|
||||||
|
- name: Eval
|
||||||
|
- name: Inspect
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
|
||||||
name: BuiltinTool
|
name: BuiltinTool
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
|
||||||
|
|
|
@ -538,7 +538,7 @@ Once the server is set up, we can test it with a client to verify it's working c
|
||||||
$ curl http://localhost:5000/inference/chat_completion \
|
$ curl http://localhost:5000/inference/chat_completion \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
-d '{
|
-d '{
|
||||||
"model": "Llama3.1-8B-Instruct",
|
"model_id": "Llama3.1-8B-Instruct",
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
{"role": "user", "content": "Write me a 2 sentence poem about the moon"}
|
{"role": "user", "content": "Write me a 2 sentence poem about the moon"}
|
||||||
|
|
|
@ -226,7 +226,7 @@ class Inference(Protocol):
|
||||||
@webmethod(route="/inference/completion")
|
@webmethod(route="/inference/completion")
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
content: InterleavedTextMedia,
|
content: InterleavedTextMedia,
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
@ -237,7 +237,7 @@ class Inference(Protocol):
|
||||||
@webmethod(route="/inference/chat_completion")
|
@webmethod(route="/inference/chat_completion")
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
# zero-shot tool definitions as input to the model
|
# zero-shot tool definitions as input to the model
|
||||||
|
@ -254,6 +254,6 @@ class Inference(Protocol):
|
||||||
@webmethod(route="/inference/embeddings")
|
@webmethod(route="/inference/embeddings")
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
contents: List[InterleavedTextMedia],
|
contents: List[InterleavedTextMedia],
|
||||||
) -> EmbeddingsResponse: ...
|
) -> EmbeddingsResponse: ...
|
||||||
|
|
|
@ -95,7 +95,7 @@ class InferenceRouter(Inference):
|
||||||
|
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
@ -106,7 +106,7 @@ class InferenceRouter(Inference):
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
params = dict(
|
params = dict(
|
||||||
model=model,
|
model_id=model_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
tools=tools or [],
|
tools=tools or [],
|
||||||
|
@ -116,7 +116,7 @@ class InferenceRouter(Inference):
|
||||||
stream=stream,
|
stream=stream,
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
)
|
)
|
||||||
provider = self.routing_table.get_provider_impl(model)
|
provider = self.routing_table.get_provider_impl(model_id)
|
||||||
if stream:
|
if stream:
|
||||||
return (chunk async for chunk in await provider.chat_completion(**params))
|
return (chunk async for chunk in await provider.chat_completion(**params))
|
||||||
else:
|
else:
|
||||||
|
@ -124,16 +124,16 @@ class InferenceRouter(Inference):
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
content: InterleavedTextMedia,
|
content: InterleavedTextMedia,
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
provider = self.routing_table.get_provider_impl(model)
|
provider = self.routing_table.get_provider_impl(model_id)
|
||||||
params = dict(
|
params = dict(
|
||||||
model=model,
|
model_id=model_id,
|
||||||
content=content,
|
content=content,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
|
@ -147,11 +147,11 @@ class InferenceRouter(Inference):
|
||||||
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
contents: List[InterleavedTextMedia],
|
contents: List[InterleavedTextMedia],
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
return await self.routing_table.get_provider_impl(model).embeddings(
|
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
||||||
model=model,
|
model_id=model_id,
|
||||||
contents=contents,
|
contents=contents,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,9 @@ def get_impl_api(p: Any) -> Api:
|
||||||
return p.__provider_spec__.api
|
return p.__provider_spec__.api
|
||||||
|
|
||||||
|
|
||||||
async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
|
# TODO: this should return the registered object for all APIs
|
||||||
|
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
|
||||||
|
|
||||||
api = get_impl_api(p)
|
api = get_impl_api(p)
|
||||||
|
|
||||||
if obj.provider_id == "remote":
|
if obj.provider_id == "remote":
|
||||||
|
@ -42,7 +44,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
|
||||||
obj.provider_id = ""
|
obj.provider_id = ""
|
||||||
|
|
||||||
if api == Api.inference:
|
if api == Api.inference:
|
||||||
await p.register_model(obj)
|
return await p.register_model(obj)
|
||||||
elif api == Api.safety:
|
elif api == Api.safety:
|
||||||
await p.register_shield(obj)
|
await p.register_shield(obj)
|
||||||
elif api == Api.memory:
|
elif api == Api.memory:
|
||||||
|
@ -167,7 +169,9 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
assert len(objects) == 1
|
assert len(objects) == 1
|
||||||
return objects[0]
|
return objects[0]
|
||||||
|
|
||||||
async def register_object(self, obj: RoutableObjectWithProvider):
|
async def register_object(
|
||||||
|
self, obj: RoutableObjectWithProvider
|
||||||
|
) -> RoutableObjectWithProvider:
|
||||||
# Get existing objects from registry
|
# Get existing objects from registry
|
||||||
existing_objects = await self.dist_registry.get(obj.type, obj.identifier)
|
existing_objects = await self.dist_registry.get(obj.type, obj.identifier)
|
||||||
|
|
||||||
|
@ -177,7 +181,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
print(
|
print(
|
||||||
f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`"
|
f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`"
|
||||||
)
|
)
|
||||||
return
|
return existing_obj
|
||||||
|
|
||||||
# if provider_id is not specified, pick an arbitrary one from existing entries
|
# if provider_id is not specified, pick an arbitrary one from existing entries
|
||||||
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
|
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
|
||||||
|
@ -188,8 +192,15 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
|
|
||||||
p = self.impls_by_provider_id[obj.provider_id]
|
p = self.impls_by_provider_id[obj.provider_id]
|
||||||
|
|
||||||
await register_object_with_provider(obj, p)
|
registered_obj = await register_object_with_provider(obj, p)
|
||||||
await self.dist_registry.register(obj)
|
# TODO: This needs to be fixed for all APIs once they return the registered object
|
||||||
|
if obj.type == ResourceType.model.value:
|
||||||
|
await self.dist_registry.register(registered_obj)
|
||||||
|
return registered_obj
|
||||||
|
|
||||||
|
else:
|
||||||
|
await self.dist_registry.register(obj)
|
||||||
|
return obj
|
||||||
|
|
||||||
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
|
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
|
||||||
objs = await self.dist_registry.get_all()
|
objs = await self.dist_registry.get_all()
|
||||||
|
@ -228,8 +239,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
await self.register_object(model)
|
registered_model = await self.register_object(model)
|
||||||
return model
|
return registered_model
|
||||||
|
|
||||||
|
|
||||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
|
|
|
@ -150,7 +150,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
||||||
messages.append(candidate.system_message)
|
messages.append(candidate.system_message)
|
||||||
messages += input_messages
|
messages += input_messages
|
||||||
response = await self.inference_api.chat_completion(
|
response = await self.inference_api.chat_completion(
|
||||||
model=candidate.model,
|
model_id=candidate.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
sampling_params=candidate.sampling_params,
|
sampling_params=candidate.sampling_params,
|
||||||
)
|
)
|
||||||
|
|
|
@ -86,6 +86,7 @@ class Llama:
|
||||||
and loads the pre-trained model and tokenizer.
|
and loads the pre-trained model and tokenizer.
|
||||||
"""
|
"""
|
||||||
model = resolve_model(config.model)
|
model = resolve_model(config.model)
|
||||||
|
llama_model = model.core_model_id.value
|
||||||
|
|
||||||
if not torch.distributed.is_initialized():
|
if not torch.distributed.is_initialized():
|
||||||
torch.distributed.init_process_group("nccl")
|
torch.distributed.init_process_group("nccl")
|
||||||
|
@ -186,13 +187,20 @@ class Llama:
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||||
return Llama(model, tokenizer, model_args)
|
return Llama(model, tokenizer, model_args, llama_model)
|
||||||
|
|
||||||
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: Transformer,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
args: ModelArgs,
|
||||||
|
llama_model: str,
|
||||||
|
):
|
||||||
self.args = args
|
self.args = args
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.formatter = ChatFormat(tokenizer)
|
self.formatter = ChatFormat(tokenizer)
|
||||||
|
self.llama_model = llama_model
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def generate(
|
def generate(
|
||||||
|
@ -369,7 +377,7 @@ class Llama:
|
||||||
self,
|
self,
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> Generator:
|
) -> Generator:
|
||||||
messages = chat_completion_request_to_messages(request)
|
messages = chat_completion_request_to_messages(request, self.llama_model)
|
||||||
|
|
||||||
sampling_params = request.sampling_params
|
sampling_params = request.sampling_params
|
||||||
max_gen_len = sampling_params.max_tokens
|
max_gen_len = sampling_params.max_tokens
|
||||||
|
|
|
@ -11,9 +11,11 @@ from typing import AsyncGenerator, List
|
||||||
from llama_models.sku_list import resolve_model
|
from llama_models.sku_list import resolve_model
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
|
||||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.inference.model_registry import build_model_alias
|
||||||
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
|
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||||
|
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
convert_image_media_to_url,
|
convert_image_media_to_url,
|
||||||
request_has_media,
|
request_has_media,
|
||||||
|
@ -28,10 +30,19 @@ from .model_parallel import LlamaModelParallelGenerator
|
||||||
SEMAPHORE = asyncio.Semaphore(1)
|
SEMAPHORE = asyncio.Semaphore(1)
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
|
||||||
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
|
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
model = resolve_model(config.model)
|
model = resolve_model(config.model)
|
||||||
|
ModelRegistryHelper.__init__(
|
||||||
|
self,
|
||||||
|
[
|
||||||
|
build_model_alias(
|
||||||
|
model.descriptor(),
|
||||||
|
model.core_model_id.value,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
if model is None:
|
if model is None:
|
||||||
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
|
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
|
||||||
self.model = model
|
self.model = model
|
||||||
|
@ -45,12 +56,6 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
else:
|
else:
|
||||||
self.generator = Llama.build(self.config)
|
self.generator = Llama.build(self.config)
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> None:
|
|
||||||
if model.identifier != self.model.descriptor():
|
|
||||||
raise ValueError(
|
|
||||||
f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
if self.config.create_distributed_process_group:
|
if self.config.create_distributed_process_group:
|
||||||
self.generator.stop()
|
self.generator.stop()
|
||||||
|
@ -68,7 +73,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
content: InterleavedTextMedia,
|
content: InterleavedTextMedia,
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
@ -79,7 +84,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||||
|
|
||||||
request = CompletionRequest(
|
request = CompletionRequest(
|
||||||
model=model,
|
model=model_id,
|
||||||
content=content,
|
content=content,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
|
@ -186,7 +191,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
|
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
@ -201,7 +206,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
|
|
||||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||||
request = ChatCompletionRequest(
|
request = ChatCompletionRequest(
|
||||||
model=model,
|
model=model_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
tools=tools or [],
|
tools=tools or [],
|
||||||
|
@ -386,7 +391,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
contents: List[InterleavedTextMedia],
|
contents: List[InterleavedTextMedia],
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
|
@ -110,7 +110,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
content: InterleavedTextMedia,
|
content: InterleavedTextMedia,
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
@ -120,7 +120,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
log.info("vLLM completion")
|
log.info("vLLM completion")
|
||||||
messages = [UserMessage(content=content)]
|
messages = [UserMessage(content=content)]
|
||||||
return self.chat_completion(
|
return self.chat_completion(
|
||||||
model=model,
|
model=model_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
@ -129,7 +129,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
|
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
tools: Optional[List[ToolDefinition]] = None,
|
tools: Optional[List[ToolDefinition]] = None,
|
||||||
|
@ -144,7 +144,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
assert self.engine is not None
|
assert self.engine is not None
|
||||||
|
|
||||||
request = ChatCompletionRequest(
|
request = ChatCompletionRequest(
|
||||||
model=model,
|
model=model_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
tools=tools or [],
|
tools=tools or [],
|
||||||
|
@ -215,7 +215,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self, model: str, contents: list[InterleavedTextMedia]
|
self, model_id: str, contents: list[InterleavedTextMedia]
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
log.info("vLLM embeddings")
|
log.info("vLLM embeddings")
|
||||||
# TODO
|
# TODO
|
||||||
|
|
|
@ -62,7 +62,7 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
|
||||||
)
|
)
|
||||||
|
|
||||||
judge_response = await self.inference_api.chat_completion(
|
judge_response = await self.inference_api.chat_completion(
|
||||||
model=fn_def.params.judge_model,
|
model_id=fn_def.params.judge_model,
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
|
|
@ -7,11 +7,15 @@
|
||||||
from typing import * # noqa: F403
|
from typing import * # noqa: F403
|
||||||
|
|
||||||
from botocore.client import BaseClient
|
from botocore.client import BaseClient
|
||||||
|
from llama_models.datatypes import CoreModelId
|
||||||
|
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
|
build_model_alias,
|
||||||
|
ModelRegistryHelper,
|
||||||
|
)
|
||||||
|
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
|
|
||||||
|
@ -19,19 +23,26 @@ from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
|
||||||
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||||
|
|
||||||
|
|
||||||
BEDROCK_SUPPORTED_MODELS = {
|
model_aliases = [
|
||||||
"Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
|
build_model_alias(
|
||||||
"Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
|
"meta.llama3-1-8b-instruct-v1:0",
|
||||||
"Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
|
CoreModelId.llama3_1_8b_instruct.value,
|
||||||
}
|
),
|
||||||
|
build_model_alias(
|
||||||
|
"meta.llama3-1-70b-instruct-v1:0",
|
||||||
|
CoreModelId.llama3_1_70b_instruct.value,
|
||||||
|
),
|
||||||
|
build_model_alias(
|
||||||
|
"meta.llama3-1-405b-instruct-v1:0",
|
||||||
|
CoreModelId.llama3_1_405b_instruct.value,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
# NOTE: this is not quite tested after the recent refactors
|
# NOTE: this is not quite tested after the recent refactors
|
||||||
class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
def __init__(self, config: BedrockConfig) -> None:
|
def __init__(self, config: BedrockConfig) -> None:
|
||||||
ModelRegistryHelper.__init__(
|
ModelRegistryHelper.__init__(self, model_aliases)
|
||||||
self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
|
|
||||||
)
|
|
||||||
self._config = config
|
self._config = config
|
||||||
|
|
||||||
self._client = create_bedrock_client(config)
|
self._client = create_bedrock_client(config)
|
||||||
|
@ -49,7 +60,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
content: InterleavedTextMedia,
|
content: InterleavedTextMedia,
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
@ -286,7 +297,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
|
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
@ -298,8 +309,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
) -> Union[
|
) -> Union[
|
||||||
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
|
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
|
||||||
]:
|
]:
|
||||||
|
model = await self.model_store.get_model(model_id)
|
||||||
request = ChatCompletionRequest(
|
request = ChatCompletionRequest(
|
||||||
model=model,
|
model=model.provider_resource_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
tools=tools or [],
|
tools=tools or [],
|
||||||
|
@ -404,7 +416,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
|
def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
|
||||||
bedrock_model = self.map_to_provider_model(request.model)
|
bedrock_model = request.model
|
||||||
inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
|
inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
|
||||||
request.sampling_params
|
request.sampling_params
|
||||||
)
|
)
|
||||||
|
@ -433,7 +445,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
contents: List[InterleavedTextMedia],
|
contents: List[InterleavedTextMedia],
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
|
@ -6,6 +6,8 @@
|
||||||
|
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
from llama_models.datatypes import CoreModelId
|
||||||
|
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import Message
|
from llama_models.llama3.api.datatypes import Message
|
||||||
|
@ -15,7 +17,10 @@ from openai import OpenAI
|
||||||
|
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
|
build_model_alias,
|
||||||
|
ModelRegistryHelper,
|
||||||
|
)
|
||||||
from llama_stack.providers.utils.inference.openai_compat import (
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
get_sampling_options,
|
get_sampling_options,
|
||||||
process_chat_completion_response,
|
process_chat_completion_response,
|
||||||
|
@ -28,16 +33,23 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
from .config import DatabricksImplConfig
|
from .config import DatabricksImplConfig
|
||||||
|
|
||||||
|
|
||||||
DATABRICKS_SUPPORTED_MODELS = {
|
model_aliases = [
|
||||||
"Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
|
build_model_alias(
|
||||||
"Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
|
"databricks-meta-llama-3-1-70b-instruct",
|
||||||
}
|
CoreModelId.llama3_1_70b_instruct.value,
|
||||||
|
),
|
||||||
|
build_model_alias(
|
||||||
|
"databricks-meta-llama-3-1-405b-instruct",
|
||||||
|
CoreModelId.llama3_1_405b_instruct.value,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
def __init__(self, config: DatabricksImplConfig) -> None:
|
def __init__(self, config: DatabricksImplConfig) -> None:
|
||||||
ModelRegistryHelper.__init__(
|
ModelRegistryHelper.__init__(
|
||||||
self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
|
self,
|
||||||
|
model_aliases=model_aliases,
|
||||||
)
|
)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||||
|
@ -113,8 +125,10 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
|
|
||||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||||
return {
|
return {
|
||||||
"model": self.map_to_provider_model(request.model),
|
"model": request.model,
|
||||||
"prompt": chat_completion_request_to_prompt(request, self.formatter),
|
"prompt": chat_completion_request_to_prompt(
|
||||||
|
request, self.get_llama_model(request.model), self.formatter
|
||||||
|
),
|
||||||
"stream": request.stream,
|
"stream": request.stream,
|
||||||
**get_sampling_options(request.sampling_params),
|
**get_sampling_options(request.sampling_params),
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,14 +7,17 @@
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from fireworks.client import Fireworks
|
from fireworks.client import Fireworks
|
||||||
|
from llama_models.datatypes import CoreModelId
|
||||||
|
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
from llama_models.llama3.api.datatypes import Message
|
from llama_models.llama3.api.datatypes import Message
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
|
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
|
build_model_alias,
|
||||||
|
ModelRegistryHelper,
|
||||||
|
)
|
||||||
from llama_stack.providers.utils.inference.openai_compat import (
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
get_sampling_options,
|
get_sampling_options,
|
||||||
process_chat_completion_response,
|
process_chat_completion_response,
|
||||||
|
@ -31,25 +34,52 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
|
||||||
from .config import FireworksImplConfig
|
from .config import FireworksImplConfig
|
||||||
|
|
||||||
FIREWORKS_SUPPORTED_MODELS = {
|
|
||||||
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
|
model_aliases = [
|
||||||
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
|
build_model_alias(
|
||||||
"Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
|
"fireworks/llama-v3p1-8b-instruct",
|
||||||
"Llama3.2-1B-Instruct": "fireworks/llama-v3p2-1b-instruct",
|
CoreModelId.llama3_1_8b_instruct.value,
|
||||||
"Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
|
),
|
||||||
"Llama3.2-11B-Vision-Instruct": "fireworks/llama-v3p2-11b-vision-instruct",
|
build_model_alias(
|
||||||
"Llama3.2-90B-Vision-Instruct": "fireworks/llama-v3p2-90b-vision-instruct",
|
"fireworks/llama-v3p1-70b-instruct",
|
||||||
"Llama-Guard-3-8B": "fireworks/llama-guard-3-8b",
|
CoreModelId.llama3_1_70b_instruct.value,
|
||||||
}
|
),
|
||||||
|
build_model_alias(
|
||||||
|
"fireworks/llama-v3p1-405b-instruct",
|
||||||
|
CoreModelId.llama3_1_405b_instruct.value,
|
||||||
|
),
|
||||||
|
build_model_alias(
|
||||||
|
"fireworks/llama-v3p2-1b-instruct",
|
||||||
|
CoreModelId.llama3_2_3b_instruct.value,
|
||||||
|
),
|
||||||
|
build_model_alias(
|
||||||
|
"fireworks/llama-v3p2-3b-instruct",
|
||||||
|
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||||
|
),
|
||||||
|
build_model_alias(
|
||||||
|
"fireworks/llama-v3p2-11b-vision-instruct",
|
||||||
|
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||||
|
),
|
||||||
|
build_model_alias(
|
||||||
|
"fireworks/llama-v3p2-90b-vision-instruct",
|
||||||
|
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||||
|
),
|
||||||
|
build_model_alias(
|
||||||
|
"fireworks/llama-guard-3-8b",
|
||||||
|
CoreModelId.llama_guard_3_8b.value,
|
||||||
|
),
|
||||||
|
build_model_alias(
|
||||||
|
"fireworks/llama-guard-3-11b-vision",
|
||||||
|
CoreModelId.llama_guard_3_11b_vision.value,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class FireworksInferenceAdapter(
|
class FireworksInferenceAdapter(
|
||||||
ModelRegistryHelper, Inference, NeedsRequestProviderData
|
ModelRegistryHelper, Inference, NeedsRequestProviderData
|
||||||
):
|
):
|
||||||
def __init__(self, config: FireworksImplConfig) -> None:
|
def __init__(self, config: FireworksImplConfig) -> None:
|
||||||
ModelRegistryHelper.__init__(
|
ModelRegistryHelper.__init__(self, model_aliases)
|
||||||
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
|
|
||||||
)
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||||
|
|
||||||
|
@ -74,15 +104,16 @@ class FireworksInferenceAdapter(
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
content: InterleavedTextMedia,
|
content: InterleavedTextMedia,
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
model = await self.model_store.get_model(model_id)
|
||||||
request = CompletionRequest(
|
request = CompletionRequest(
|
||||||
model=model,
|
model=model.provider_resource_id,
|
||||||
content=content,
|
content=content,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
|
@ -138,7 +169,7 @@ class FireworksInferenceAdapter(
|
||||||
|
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
tools: Optional[List[ToolDefinition]] = None,
|
tools: Optional[List[ToolDefinition]] = None,
|
||||||
|
@ -148,8 +179,9 @@ class FireworksInferenceAdapter(
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
model = await self.model_store.get_model(model_id)
|
||||||
request = ChatCompletionRequest(
|
request = ChatCompletionRequest(
|
||||||
model=model,
|
model=model.provider_resource_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
tools=tools or [],
|
tools=tools or [],
|
||||||
|
@ -207,7 +239,7 @@ class FireworksInferenceAdapter(
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
input_dict["prompt"] = chat_completion_request_to_prompt(
|
input_dict["prompt"] = chat_completion_request_to_prompt(
|
||||||
request, self.formatter
|
request, self.get_llama_model(request.model), self.formatter
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert (
|
assert (
|
||||||
|
@ -221,7 +253,7 @@ class FireworksInferenceAdapter(
|
||||||
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
|
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model": self.map_to_provider_model(request.model),
|
"model": request.model,
|
||||||
**input_dict,
|
**input_dict,
|
||||||
"stream": request.stream,
|
"stream": request.stream,
|
||||||
**self._build_options(request.sampling_params, request.response_format),
|
**self._build_options(request.sampling_params, request.response_format),
|
||||||
|
@ -229,7 +261,7 @@ class FireworksInferenceAdapter(
|
||||||
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
contents: List[InterleavedTextMedia],
|
contents: List[InterleavedTextMedia],
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
|
@ -7,15 +7,20 @@
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from llama_models.datatypes import CoreModelId
|
||||||
|
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
from llama_models.llama3.api.datatypes import Message
|
from llama_models.llama3.api.datatypes import Message
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
|
|
||||||
from ollama import AsyncClient
|
from ollama import AsyncClient
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
|
build_model_alias,
|
||||||
|
ModelRegistryHelper,
|
||||||
|
)
|
||||||
|
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference.openai_compat import (
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
get_sampling_options,
|
get_sampling_options,
|
||||||
|
@ -33,19 +38,45 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
request_has_media,
|
request_has_media,
|
||||||
)
|
)
|
||||||
|
|
||||||
OLLAMA_SUPPORTED_MODELS = {
|
|
||||||
"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
|
model_aliases = [
|
||||||
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
|
build_model_alias(
|
||||||
"Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
|
"llama3.1:8b-instruct-fp16",
|
||||||
"Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
|
CoreModelId.llama3_1_8b_instruct.value,
|
||||||
"Llama-Guard-3-8B": "llama-guard3:8b",
|
),
|
||||||
"Llama-Guard-3-1B": "llama-guard3:1b",
|
build_model_alias(
|
||||||
"Llama3.2-11B-Vision-Instruct": "x/llama3.2-vision:11b-instruct-fp16",
|
"llama3.1:70b-instruct-fp16",
|
||||||
}
|
CoreModelId.llama3_1_70b_instruct.value,
|
||||||
|
),
|
||||||
|
build_model_alias(
|
||||||
|
"llama3.2:1b-instruct-fp16",
|
||||||
|
CoreModelId.llama3_2_1b_instruct.value,
|
||||||
|
),
|
||||||
|
build_model_alias(
|
||||||
|
"llama3.2:3b-instruct-fp16",
|
||||||
|
CoreModelId.llama3_2_3b_instruct.value,
|
||||||
|
),
|
||||||
|
build_model_alias(
|
||||||
|
"llama-guard3:8b",
|
||||||
|
CoreModelId.llama_guard_3_8b.value,
|
||||||
|
),
|
||||||
|
build_model_alias(
|
||||||
|
"llama-guard3:1b",
|
||||||
|
CoreModelId.llama_guard_3_1b.value,
|
||||||
|
),
|
||||||
|
build_model_alias(
|
||||||
|
"x/llama3.2-vision:11b-instruct-fp16",
|
||||||
|
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
|
||||||
def __init__(self, url: str) -> None:
|
def __init__(self, url: str) -> None:
|
||||||
|
ModelRegistryHelper.__init__(
|
||||||
|
self,
|
||||||
|
model_aliases=model_aliases,
|
||||||
|
)
|
||||||
self.url = url
|
self.url = url
|
||||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||||
|
|
||||||
|
@ -65,44 +96,18 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> None:
|
|
||||||
if model.identifier not in OLLAMA_SUPPORTED_MODELS:
|
|
||||||
raise ValueError(f"Model {model.identifier} is not supported by Ollama")
|
|
||||||
|
|
||||||
async def list_models(self) -> List[Model]:
|
|
||||||
ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()}
|
|
||||||
|
|
||||||
ret = []
|
|
||||||
res = await self.client.ps()
|
|
||||||
for r in res["models"]:
|
|
||||||
if r["model"] not in ollama_to_llama:
|
|
||||||
print(f"Ollama is running a model unknown to Llama Stack: {r['model']}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
llama_model = ollama_to_llama[r["model"]]
|
|
||||||
print(f"Found model {llama_model} in Ollama")
|
|
||||||
ret.append(
|
|
||||||
Model(
|
|
||||||
identifier=llama_model,
|
|
||||||
metadata={
|
|
||||||
"ollama_model": r["model"],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return ret
|
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
content: InterleavedTextMedia,
|
content: InterleavedTextMedia,
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
model = await self.model_store.get_model(model_id)
|
||||||
request = CompletionRequest(
|
request = CompletionRequest(
|
||||||
model=model,
|
model=model.provider_resource_id,
|
||||||
content=content,
|
content=content,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
@ -148,7 +153,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
|
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
@ -158,8 +163,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
model = await self.model_store.get_model(model_id)
|
||||||
|
print(f"model={model}")
|
||||||
request = ChatCompletionRequest(
|
request = ChatCompletionRequest(
|
||||||
model=model,
|
model=model.provider_resource_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
tools=tools or [],
|
tools=tools or [],
|
||||||
|
@ -197,7 +204,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
else:
|
else:
|
||||||
input_dict["raw"] = True
|
input_dict["raw"] = True
|
||||||
input_dict["prompt"] = chat_completion_request_to_prompt(
|
input_dict["prompt"] = chat_completion_request_to_prompt(
|
||||||
request, self.formatter
|
request, self.get_llama_model(request.model), self.formatter
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert (
|
assert (
|
||||||
|
@ -207,7 +214,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
input_dict["raw"] = True
|
input_dict["raw"] = True
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model": OLLAMA_SUPPORTED_MODELS[request.model],
|
"model": request.model,
|
||||||
**input_dict,
|
**input_dict,
|
||||||
"options": sampling_options,
|
"options": sampling_options,
|
||||||
"stream": request.stream,
|
"stream": request.stream,
|
||||||
|
@ -271,7 +278,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model_id: str,
|
||||||
contents: List[InterleavedTextMedia],
|
contents: List[InterleavedTextMedia],
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
|
@@ -6,6 +6,8 @@

 from typing import AsyncGenerator

+from llama_models.datatypes import CoreModelId
+
 from llama_models.llama3.api.chat_format import ChatFormat

 from llama_models.llama3.api.datatypes import Message
@@ -15,7 +17,10 @@ from together import Together

 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import (
+    build_model_alias,
+    ModelRegistryHelper,
+)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
@@ -33,25 +38,47 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import TogetherImplConfig


-TOGETHER_SUPPORTED_MODELS = {
-    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
-    "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
-    "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
-    "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
-    "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
-    "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
-    "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
-    "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
-}
+model_aliases = [
+    build_model_alias(
+        "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
+        CoreModelId.llama3_1_8b_instruct.value,
+    ),
+    build_model_alias(
+        "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+        CoreModelId.llama3_1_70b_instruct.value,
+    ),
+    build_model_alias(
+        "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
+        CoreModelId.llama3_1_405b_instruct.value,
+    ),
+    build_model_alias(
+        "meta-llama/Llama-3.2-3B-Instruct-Turbo",
+        CoreModelId.llama3_2_3b_instruct.value,
+    ),
+    build_model_alias(
+        "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
+        CoreModelId.llama3_2_11b_vision_instruct.value,
+    ),
+    build_model_alias(
+        "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
+        CoreModelId.llama3_2_90b_vision_instruct.value,
+    ),
+    build_model_alias(
+        "meta-llama/Meta-Llama-Guard-3-8B",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+    build_model_alias(
+        "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
+        CoreModelId.llama_guard_3_11b_vision.value,
+    ),
+]


 class TogetherInferenceAdapter(
     ModelRegistryHelper, Inference, NeedsRequestProviderData
 ):
     def __init__(self, config: TogetherImplConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
-        )
+        ModelRegistryHelper.__init__(self, model_aliases)
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())

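The second argument of each build_model_alias call is a CoreModelId value, i.e. the canonical Llama descriptor string that previously served as the key of TOGETHER_SUPPORTED_MODELS. A quick way to see the strings involved, assuming the llama_models package is installed (expected outputs are inferred from the old mapping above, not captured results):

# Peek at the descriptor strings behind the CoreModelId values used above.
from llama_models.datatypes import CoreModelId

print(CoreModelId.llama3_1_8b_instruct.value)  # expected: Llama3.1-8B-Instruct
print(CoreModelId.llama_guard_3_8b.value)      # expected: Llama-Guard-3-8B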
@@ -63,15 +90,16 @@ class TogetherInferenceAdapter(

     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
-            model=model,
+            model=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -135,7 +163,7 @@ class TogetherInferenceAdapter(

     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
@@ -145,8 +173,9 @@ class TogetherInferenceAdapter(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
-            model=model,
+            model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -204,7 +233,7 @@ class TogetherInferenceAdapter(
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.formatter
+                    request, self.get_llama_model(request.model), self.formatter
                 )
         else:
             assert (
@@ -213,7 +242,7 @@ class TogetherInferenceAdapter(
             input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)

         return {
-            "model": self.map_to_provider_model(request.model),
+            "model": request.model,
             **input_dict,
             "stream": request.stream,
             **self._build_options(request.sampling_params, request.response_format),
@@ -221,7 +250,7 @@ class TogetherInferenceAdapter(

     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
@@ -8,13 +8,17 @@ from typing import AsyncGenerator
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import all_registered_models, resolve_model
+from llama_models.sku_list import all_registered_models

 from openai import OpenAI

 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
+from llama_stack.providers.datatypes import ModelsProtocolPrivate

+from llama_stack.providers.utils.inference.model_registry import (
+    build_model_alias,
+    ModelRegistryHelper,
+)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
@@ -30,44 +34,36 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import VLLMInferenceAdapterConfig


-class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
+def build_model_aliases():
+    return [
+        build_model_alias(
+            model.huggingface_repo,
+            model.descriptor(),
+        )
+        for model in all_registered_models()
+        if model.huggingface_repo
+    ]
+
+
+class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
+        ModelRegistryHelper.__init__(
+            self,
+            model_aliases=build_model_aliases(),
+        )
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
         self.client = None
-        self.huggingface_repo_to_llama_model_id = {
-            model.huggingface_repo: model.descriptor()
-            for model in all_registered_models()
-            if model.huggingface_repo
-        }

     async def initialize(self) -> None:
         self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)

-    async def register_model(self, model: Model) -> None:
-        for running_model in self.client.models.list():
-            repo = running_model.id
-            if repo not in self.huggingface_repo_to_llama_model_id:
-                print(f"Unknown model served by vllm: {repo}")
-                continue
-
-            identifier = self.huggingface_repo_to_llama_model_id[repo]
-            if identifier == model.provider_resource_id:
-                print(
-                    f"Verified that model {model.provider_resource_id} is being served by vLLM"
-                )
-                return
-
-        raise ValueError(
-            f"Model {model.provider_resource_id} is not being served by vLLM"
-        )
-
     async def shutdown(self) -> None:
         pass

     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
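Unlike Together, the vLLM adapter does not hard-code an alias list; build_model_aliases() derives one alias per registered Llama model that has a HuggingFace repo, since vLLM serves models under their HuggingFace names. A small sketch of what that enumeration yields, assuming the llama_models package is installed (it reuses exactly the calls shown above):

# Sketch of what build_model_aliases() enumerates (requires the llama_models package).
from llama_models.sku_list import all_registered_models

for model in all_registered_models():
    if model.huggingface_repo:
        # the HuggingFace repo becomes the provider_model_id; the Llama
        # descriptor stays available as an alias for registration and lookup
        print(f"{model.huggingface_repo} -> {model.descriptor()}")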
@@ -78,7 +74,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):

     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -88,8 +84,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
-            model=model,
+            model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -141,10 +138,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         if "max_tokens" not in options:
             options["max_tokens"] = self.config.max_tokens

-        model = resolve_model(request.model)
-        if model is None:
-            raise ValueError(f"Unknown model: {request.model}")
-
         input_dict = {}
         media_present = request_has_media(request)
         if isinstance(request, ChatCompletionRequest):
@@ -156,16 +149,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.formatter
+                    request, self.get_llama_model(request.model), self.formatter
                 )
         else:
             assert (
                 not media_present
             ), "Together does not support media for Completion requests"
-            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = completion_request_to_prompt(
+                request,
+                self.get_llama_model(request.model),
+                self.formatter,
+            )

         return {
-            "model": model.huggingface_repo,
+            "model": request.model,
             **input_dict,
             "stream": request.stream,
             **options,
@@ -173,7 +170,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):

     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
@@ -49,7 +49,7 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
         providers=[
             Provider(
                 provider_id=f"meta-reference-{i}",
-                provider_type="meta-reference",
+                provider_type="inline::meta-reference",
                 config=MetaReferenceInferenceConfig(
                     model=m,
                     max_seq_len=4096,
@@ -142,6 +142,31 @@ def inference_bedrock() -> ProviderFixture:
     )


+def get_model_short_name(model_name: str) -> str:
+    """Convert model name to a short test identifier.
+
+    Args:
+        model_name: Full model name like "Llama3.1-8B-Instruct"
+
+    Returns:
+        Short name like "llama_8b" suitable for test markers
+    """
+    model_name = model_name.lower()
+    if "vision" in model_name:
+        return "llama_vision"
+    elif "3b" in model_name:
+        return "llama_3b"
+    elif "8b" in model_name:
+        return "llama_8b"
+    else:
+        return model_name.replace(".", "_").replace("-", "_")
+
+
+@pytest.fixture(scope="session")
+def model_id(inference_model) -> str:
+    return get_model_short_name(inference_model)
+
+
 INFERENCE_FIXTURES = [
     "meta_reference",
     "ollama",
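get_model_short_name is a pure string helper, so its behavior can be checked directly. A few expected mappings, runnable with the function above in scope (model names taken from elsewhere in this change):

# Quick check of the helper added above (pure function, no fixtures needed).
assert get_model_short_name("Llama3.1-8B-Instruct") == "llama_8b"
assert get_model_short_name("Llama3.2-3B-Instruct") == "llama_3b"
assert get_model_short_name("Llama3.2-11B-Vision-Instruct") == "llama_vision"
# anything without a vision/3b/8b marker falls back to a sanitized lowercase name
assert get_model_short_name("Llama3.1-405B-Instruct") == "llama3_1_405b_instruct"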
@@ -96,7 +96,7 @@ class TestInference:
         response = await inference_impl.completion(
             content="Micheael Jordan is born in ",
             stream=False,
-            model=inference_model,
+            model_id=inference_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -110,7 +110,7 @@ class TestInference:
            async for r in await inference_impl.completion(
                content="Roses are red,",
                stream=True,
-                model=inference_model,
+                model_id=inference_model,
                sampling_params=SamplingParams(
                    max_tokens=50,
                ),
@@ -171,7 +171,7 @@ class TestInference:
    ):
        inference_impl, _ = inference_stack
        response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=inference_model,
            messages=sample_messages,
            stream=False,
            **common_params,
@@ -204,7 +204,7 @@ class TestInference:
            num_seasons_in_nba: int

        response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=inference_model,
            messages=[
                SystemMessage(content="You are a helpful assistant."),
                UserMessage(content="Please give me information about Michael Jordan."),
@@ -227,7 +227,7 @@ class TestInference:
        assert answer.num_seasons_in_nba == 15

        response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=inference_model,
            messages=[
                SystemMessage(content="You are a helpful assistant."),
                UserMessage(content="Please give me information about Michael Jordan."),
@@ -250,7 +250,7 @@ class TestInference:
        response = [
            r
            async for r in await inference_impl.chat_completion(
-                model=inference_model,
+                model_id=inference_model,
                messages=sample_messages,
                stream=True,
                **common_params,
@@ -286,7 +286,7 @@ class TestInference:
        ]

        response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=inference_model,
            messages=messages,
            tools=[sample_tool_definition],
            stream=False,
@@ -327,7 +327,7 @@ class TestInference:
        response = [
            r
            async for r in await inference_impl.chat_completion(
-                model=inference_model,
+                model_id=inference_model,
                messages=messages,
                tools=[sample_tool_definition],
                stream=True,
@@ -4,32 +4,61 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Dict
+from collections import namedtuple
+from typing import List, Optional

-from llama_models.sku_list import resolve_model
+from llama_models.sku_list import all_registered_models

 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate

+ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])
+
+
+def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
+    for model in all_registered_models():
+        if model.descriptor() == model_descriptor:
+            return model.huggingface_repo
+    return None
+
+
+def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias:
+    return ModelAlias(
+        provider_model_id=provider_model_id,
+        aliases=[
+            model_descriptor,
+            get_huggingface_repo(model_descriptor),
+        ],
+        llama_model=model_descriptor,
+    )
+
+
 class ModelRegistryHelper(ModelsProtocolPrivate):
+    def __init__(self, model_aliases: List[ModelAlias]):
+        self.alias_to_provider_id_map = {}
+        self.provider_id_to_llama_model_map = {}
+        for alias_obj in model_aliases:
+            for alias in alias_obj.aliases:
+                self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
+            # also add a mapping from provider model id to itself for easy lookup
+            self.alias_to_provider_id_map[alias_obj.provider_model_id] = (
+                alias_obj.provider_model_id
+            )
+            self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = (
+                alias_obj.llama_model
+            )

-    def __init__(self, stack_to_provider_models_map: Dict[str, str]):
-        self.stack_to_provider_models_map = stack_to_provider_models_map
-
-    def map_to_provider_model(self, identifier: str) -> str:
-        model = resolve_model(identifier)
-        if not model:
+    def get_provider_model_id(self, identifier: str) -> str:
+        if identifier in self.alias_to_provider_id_map:
+            return self.alias_to_provider_id_map[identifier]
+        else:
             raise ValueError(f"Unknown model: `{identifier}`")

-        if identifier not in self.stack_to_provider_models_map:
-            raise ValueError(
-                f"Model {identifier} not found in map {self.stack_to_provider_models_map}"
-            )
-
-        return self.stack_to_provider_models_map[identifier]
+    def get_llama_model(self, provider_model_id: str) -> str:
+        return self.provider_id_to_llama_model_map[provider_model_id]

-    async def register_model(self, model: Model) -> None:
-        if model.identifier not in self.stack_to_provider_models_map:
-            raise ValueError(
-                f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
-            )
+    async def register_model(self, model: Model) -> Model:
+        model.provider_resource_id = self.get_provider_model_id(
+            model.provider_resource_id
+        )
+
+        return model
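The net effect of the new helper is that register_model no longer validates membership in a static map and returns None; it rewrites provider_resource_id through the alias table and returns the object. A standalone sketch of that contract, using a stand-in model class and a single made-up alias entry rather than the llama-stack types:

# Standalone sketch of the new register_model contract (illustrative stand-ins only).
from dataclasses import dataclass


@dataclass
class FakeModel:
    identifier: str
    provider_resource_id: str


alias_to_provider_id = {
    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
    "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
}


def register_model(model: FakeModel) -> FakeModel:
    # resolve whatever name the caller supplied into the provider's own id,
    # or fail loudly if no alias matches
    if model.provider_resource_id not in alias_to_provider_id:
        raise ValueError(f"Unknown model: `{model.provider_resource_id}`")
    model.provider_resource_id = alias_to_provider_id[model.provider_resource_id]
    return model


registered = register_model(FakeModel("my-llama", "Llama3.1-8B-Instruct"))
print(registered.provider_resource_id)  # meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo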
@@ -147,17 +147,17 @@ def augment_content_with_response_format_prompt(response_format, content):


 def chat_completion_request_to_prompt(
-    request: ChatCompletionRequest, formatter: ChatFormat
+    request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
 ) -> str:
-    messages = chat_completion_request_to_messages(request)
+    messages = chat_completion_request_to_messages(request, llama_model)
     model_input = formatter.encode_dialog_prompt(messages)
     return formatter.tokenizer.decode(model_input.tokens)


 def chat_completion_request_to_model_input_info(
-    request: ChatCompletionRequest, formatter: ChatFormat
+    request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
 ) -> Tuple[str, int]:
-    messages = chat_completion_request_to_messages(request)
+    messages = chat_completion_request_to_messages(request, llama_model)
     model_input = formatter.encode_dialog_prompt(messages)
     return (
         formatter.tokenizer.decode(model_input.tokens),
@@ -167,14 +167,15 @@ def chat_completion_request_to_model_input_info(

 def chat_completion_request_to_messages(
     request: ChatCompletionRequest,
+    llama_model: str,
 ) -> List[Message]:
     """Reads chat completion request and augments the messages to handle tools.
     For eg. for llama_3_1, add system message with the appropriate tools or
     add user messsage for custom tools, etc.
     """
-    model = resolve_model(request.model)
+    model = resolve_model(llama_model)
     if model is None:
-        cprint(f"Could not resolve model {request.model}", color="red")
+        cprint(f"Could not resolve model {llama_model}", color="red")
         return request.messages

     if model.descriptor() not in supported_inference_models():
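The extra llama_model parameter exists because request.model now carries the provider's own name (for example a Together "-Turbo" id), which llama_models cannot resolve; adapters therefore pass the Llama descriptor obtained from get_llama_model(). A small sketch of the difference, assuming the llama_models package is installed (outputs are expectations, not captured results):

# Why prompt construction now needs the Llama descriptor explicitly.
from llama_models.sku_list import resolve_model

# a provider-specific name is opaque to resolve_model...
print(resolve_model("meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"))  # expected: None
# ...while the Llama descriptor resolves to the full model metadata
model = resolve_model("Llama3.1-8B-Instruct")
print(model.descriptor() if model else None)  # expected: Llama3.1-8B-Instruct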