mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-25 09:05:37 +00:00 
			
		
		
		
	# What does this PR do? [Provide a short summary of what this PR does and why. Link to relevant issues if applicable.] [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation)
		
			
				
	
	
		
			189 lines
		
	
	
	
		
			6.4 KiB
		
	
	
	
		
			Swift
		
	
	
	
	
	
			
		
		
	
	
			189 lines
		
	
	
	
		
			6.4 KiB
		
	
	
	
		
			Swift
		
	
	
	
	
	
| import Foundation
 | |
| 
 | |
| import LLaMARunner
 | |
| import LlamaStackClient
 | |
| 
 | |
/// Reference-type box that holds the optional `Runner`.
///
/// NOTE(review): conforms to `ObservableObject`, but `runner` is not
/// `@Published`, so assigning it does not emit `objectWillChange`.
class RunnerHolder: ObservableObject {
  /// The wrapped runner; `nil` until something assigns it.
  var runner: Runner? = nil
}
 | |
| 
 | |
/// On-device `Inference` implementation backed by a LLaMARunner `Runner`.
///
/// Model loading and token generation are both dispatched onto the
/// `DispatchQueue` passed to `init(queue:)`.
public class LocalInference: Inference {
  /// Box holding the runner created by `loadModel`.
  private let runnerHolder = RunnerHolder()
  /// Queue on which `load()` and `generate(...)` are executed.
  private let runnerQueue: DispatchQueue

  /// Creates an inference engine that runs model work on `queue`.
  ///
  /// - Parameter queue: Queue used for model loading and token generation.
  public init (queue: DispatchQueue) {
    runnerQueue = queue
  }

  /// Creates the runner if needed, then loads it asynchronously on the runner queue.
  ///
  /// - Note: An already-created runner is reused; in that case `modelPath` and
  ///   `tokenizerPath` are ignored.
  /// - Parameters:
  ///   - modelPath: Passed to `Runner(modelPath:tokenizerPath:)`.
  ///   - tokenizerPath: Passed to `Runner(modelPath:tokenizerPath:)`.
  ///   - completion: Invoked on the runner queue with `.success(())` after
  ///     `load()` returns, or `.failure` carrying the error it threw.
  public func loadModel(modelPath: String, tokenizerPath: String, completion: @escaping (Result<Void, Error>) -> Void) {
    if runnerHolder.runner == nil {
      runnerHolder.runner = Runner(
        modelPath: modelPath,
        tokenizerPath: tokenizerPath
      )
    }

    runnerQueue.async {
      do {
        // Assigned just above; `runnerHolder` is private and nothing in this
        // class resets `runner` to nil.
        try self.runnerHolder.runner!.load()
        completion(.success(()))
      } catch {
        print("error: " + error.localizedDescription)
        completion(.failure(error))
      }
    }
  }

  /// Forwards to `Runner.stop()` on the current runner, if one exists.
  public func stop() {
    runnerHolder.runner?.stop()
  }

  /// Streams a chat completion for `request`, generated on the runner queue.
  ///
  /// Emits `.progress` chunks with text deltas, or tool-call deltas once the
  /// output starts with `<|python_tag|>` or `[`. Then it emits one `.succeeded`
  /// tool-call chunk per parsed tool call, or a single `.failed` tool-call
  /// chunk if tool-call mode was entered but nothing parsed. Finally it emits a
  /// `.complete` chunk. The stream is always finished afterwards, including
  /// when prompt encoding or generation throws. Such errors are only printed.
  ///
  /// - Note: If no runner has been created via `loadModel`, nothing is
  ///   generated and only the trailing chunks are emitted.
  public func chatCompletion(request: Components.Schemas.ChatCompletionRequest) -> AsyncStream<Components.Schemas.ChatCompletionResponseStreamChunk> {
    return AsyncStream { continuation in
      let workItem = DispatchWorkItem {
        // Finish on every exit path (success or thrown error) so `for await`
        // consumers terminate instead of waiting forever.
        defer { continuation.finish() }

        do {
          var tokens: [String] = []

          let prompt = try encodeDialogPrompt(messages: prepareMessages(request: request))
          var stopReason: Components.Schemas.CompletionMessage.stop_reasonPayload? = nil
          var buffer = ""
          var ipython = false
          var echoDropped = false

          try self.runnerHolder.runner?.generate(prompt, sequenceLength: 4096) { token in
            buffer += token

            // HACK: Workaround until LlamaRunner exposes echo param.
            // The runner's output appears to start with the prompt itself, so
            // tokens are swallowed until the buffer contains the whole prompt,
            // which is then stripped.
            if !echoDropped {
              if buffer.hasPrefix(prompt) {
                buffer = String(buffer.dropFirst(prompt.count))
                echoDropped = true
              }
              return
            }

            tokens.append(token)

            // Enter tool-call mode when the output begins with the python tag
            // or with "[" (presumably the start of a bracketed tool call).
            if !ipython && (buffer.starts(with: "<|python_tag|>") || buffer.starts(with: "[")) {
              ipython = true
              continuation.yield(LocalInference.progressChunk(
                LocalInference.toolCallDelta(.case1(""), status: .started)
              ))

              if buffer.starts(with: "<|python_tag|>") {
                buffer = String(buffer.dropFirst("<|python_tag|>".count))
              }
            }

            // TODO: Non-streaming logprobs

            // Stop tokens record the stop reason and contribute no text.
            var text = ""
            switch token {
            case "<|eot_id|>":
              stopReason = Components.Schemas.CompletionMessage.stop_reasonPayload.end_of_turn
            case "<|eom_id|>":
              stopReason = Components.Schemas.CompletionMessage.stop_reasonPayload.end_of_message
            default:
              text = token
            }

            // Once a stop token has been seen, nothing more is streamed.
            if stopReason == nil {
              let delta = ipython
                ? LocalInference.toolCallDelta(.case1(text), status: .in_progress)
                : LocalInference.textDelta(text)
              continuation.yield(LocalInference.progressChunk(delta))
            }
          }

          // No stop token seen: generation ended for another reason.
          let finalStopReason = stopReason ?? Components.Schemas.CompletionMessage.stop_reasonPayload.out_of_tokens
          let message = decodeAssistantMessage(tokens: tokens.joined(), stopReason: finalStopReason)
          // TODO: non-streaming support

          // `tool_calls` is optional; treat a missing list as "no tool calls"
          // rather than force-unwrapping it.
          let toolCalls = message.tool_calls ?? []
          if ipython && toolCalls.isEmpty {
            continuation.yield(LocalInference.progressChunk(
              LocalInference.toolCallDelta(.case1(""), status: .failed)
            ))
            // TODO: stopReason
          }

          for toolCall in toolCalls {
            continuation.yield(LocalInference.progressChunk(
              LocalInference.toolCallDelta(.ToolCall(toolCall), status: .succeeded)
            ))
            // TODO: stopReason
          }

          continuation.yield(
            Components.Schemas.ChatCompletionResponseStreamChunk(
              event: Components.Schemas.ChatCompletionResponseEvent(
                event_type: .complete,
                delta: LocalInference.textDelta("")
              )
              // TODO: stopReason
            )
          )
        } catch {
          print("Inference error: " + error.localizedDescription)
        }
      }
      runnerQueue.async(execute: workItem)
    }
  }

  // MARK: - Chunk builders

  /// Wraps `delta` in a `.progress` stream chunk.
  private static func progressChunk(_ delta: Components.Schemas.ContentDelta) -> Components.Schemas.ChatCompletionResponseStreamChunk {
    return Components.Schemas.ChatCompletionResponseStreamChunk(
      event: Components.Schemas.ChatCompletionResponseEvent(
        event_type: .progress,
        delta: delta
      )
    )
  }

  /// Builds a tool-call content delta carrying `toolCall` with parse status `status`.
  private static func toolCallDelta(
    _ toolCall: Components.Schemas.ToolCallDelta.tool_callPayload,
    status: Components.Schemas.ToolCallDelta.parse_statusPayload
  ) -> Components.Schemas.ContentDelta {
    return .tool_call(Components.Schemas.ToolCallDelta(
      _type: Components.Schemas.ToolCallDelta._typePayload.tool_call,
      tool_call: toolCall,
      parse_status: status
    ))
  }

  /// Builds a plain-text content delta.
  private static func textDelta(_ text: String) -> Components.Schemas.ContentDelta {
    return .text(Components.Schemas.TextDelta(
      _type: Components.Schemas.TextDelta._typePayload.text,
      text: text
    ))
  }
}
 |