Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import CoreML
import Foundation
import Generation
import Models
import Tokenizers

@available(macOS 15.0, iOS 18.0, *)
@main
Expand Down Expand Up @@ -108,12 +109,13 @@ struct TransformersCLI: AsyncParsableCommand {
let compiledURL = try compile(at: url)
print("Loading model \(compiledURL)")
let model: LanguageModel
if let tokenizerFolder {
let tokenizerURL = URL(filePath: tokenizerFolder, directoryHint: .isDirectory)
if let tokenizerPath {
let tokenizerURL = URL(filePath: tokenizerPath, directoryHint: .isDirectory)
let tokenizer = try await AutoTokenizer.from(modelFolder: tokenizerURL)
model = try LanguageModel.loadCompiled(
url: compiledURL,
tokenizerFolder: tokenizerURL,
computeUnits: computeUnits.asMLComputeUnits
computeUnits: computeUnits.asMLComputeUnits,
tokenizer: tokenizer
)
} else {
model = try LanguageModel.loadCompiled(url: compiledURL, computeUnits: computeUnits.asMLComputeUnits)
Expand Down
86 changes: 35 additions & 51 deletions Sources/Generation/LogitsWarper/MinPLogitsWarper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,71 +44,55 @@ public struct MinPLogitsWarper: LogitsProcessor {
public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor {
// Algorithm (following transformers implementation):
// 1. Compute probabilities from logits
// 2. Find max probability per batch
// 3. Create threshold = minP * maxProb
// 4. Sort logits and mask tokens where prob < threshold
// 5. Keep at least minTokensToKeep
// 6. Scatter back to original order
// 2. Find max probability per batch (with keepdim)
// 3. Calculate threshold = minP * maxProb
// 4. Create mask for tokens where prob < threshold
// 5. Use topK to get min_tokens_to_keep and unmask them
// 6. Apply mask to scores

let vocabSize = scores.shape[scores.rank - 1]

// Compute probabilities
// Convert logits to probabilities
let probs = scores.softmax(alongAxis: -1)

// Sort probabilities descending to get max (first element)
let sortedProbIndices = probs.argsort(alongAxis: -1, descendingOrder: true)
let sortedProbs = probs.gathering(atIndices: sortedProbIndices, alongAxis: -1)
// Get the probability of the top token for each sequence in the batch
// Using max with keepRank=true to maintain dimensions for broadcasting
let topProbs = probs.max(alongAxes: [-1], keepRank: true)

// Extract max prob per batch: first element of each sorted sequence
// Do this on CPU to avoid complex broadcasting issues
let sortedProbsArray = await sortedProbs.shapedArray(of: Float.self)
let batchSize = scores.shape[0]
var thresholdScalars = [Float]()
thresholdScalars.reserveCapacity(batchSize * vocabSize)
for batchIdx in 0..<batchSize {
let maxProb = sortedProbsArray.scalars[batchIdx * vocabSize] // First element
let thresholdVal = minP * maxProb
for _ in 0..<vocabSize {
thresholdScalars.append(thresholdVal)
}
}
let threshold = MLTensor(shape: probs.shape, scalars: thresholdScalars, scalarType: Float.self)
// Calculate the actual min_p threshold by scaling min_p with the top token's probability
let scaledMinP = topProbs * minP

// Create mask: tokensToRemove where prob < threshold
let tokensToRemove = probs .< threshold
// Create a mask for tokens that have a probability less than the scaled min_p
let tokensToRemove = probs .< scaledMinP

// Sort scores descending
let sortedScoreIndices = scores.argsort(alongAxis: -1, descendingOrder: true)
let inversePermutation = sortedScoreIndices.argsort(alongAxis: -1)
// Keep at least min_tokens_to_keep tokens (clip k to vocab size if needed)
let k = min(minTokensToKeep, vocabSize)

// Gather mask in sorted order
let sortedTokensToRemove = tokensToRemove.gathering(atIndices: sortedScoreIndices, alongAxis: -1)
// Get indices of top-k probabilities
let topKResult = probs.topK(k)
let topKIndices = topKResult.indices

// Create position tensor for minTokensToKeep check
let posBaseShape = Array(repeating: 1, count: scores.rank - 1) + [vocabSize]
var posMultiples = scores.shape
posMultiples[posMultiples.count - 1] = 1
// Create a mask to keep the top-k tokens
// Since MLTensor doesn't have a scatter operation that works like PyTorch's scatter_,
// we use replacing(atIndices:with:alongAxis:) which replaces values at specified indices.
// For our case, we want to unmask (set to False/0) the top-k token positions.

let positions = MLTensor(
rangeFrom: Int32(0),
to: Int32(vocabSize),
by: 1,
scalarType: Int32.self
)
.reshaped(to: posBaseShape)
.tiled(multiples: posMultiples)
// Convert boolean mask to Int32 (1 = remove, 0 = keep)
let zerosInt = MLTensor(repeating: Int32(0), shape: tokensToRemove.shape, scalarType: Int32.self)
let onesInt = MLTensor(repeating: Int32(1), shape: tokensToRemove.shape, scalarType: Int32.self)
let tokensToRemoveAsInt = zerosInt.replacing(with: onesInt, where: tokensToRemove)

// Mask: remove if (position >= minTokensToKeep AND shouldRemove)
let beyondMinimum = positions .>= Int32(minTokensToKeep)
let finalRemoveMask = sortedTokensToRemove .& beyondMinimum
// Try using replacing(atIndices:with:alongAxis:) which takes a scalar value
// This replaces slices at the specified indices with the scalar value
let finalTokensToRemoveInt = tokensToRemoveAsInt.replacing(atIndices: topKIndices, with: Int32(0), alongAxis: -1)

// Apply filter in sorted space
let sortedScores = scores.gathering(atIndices: sortedScoreIndices, alongAxis: -1)
let filterTensor = MLTensor(repeating: filterValue, shape: sortedScores.shape, scalarType: Float.self)
let filteredSorted = sortedScores.replacing(with: filterTensor, where: finalRemoveMask)
// Convert back to boolean mask
let zerosComparison = MLTensor(repeating: Int32(0), shape: tokensToRemove.shape, scalarType: Int32.self)
let finalTokensToRemove = finalTokensToRemoveInt .!= zerosComparison

// Scatter back to original order
return filteredSorted.gathering(atIndices: inversePermutation, alongAxis: -1)
// Apply mask to scores
let filterTensor = MLTensor(repeating: filterValue, shape: scores.shape, scalarType: Float.self)
return scores.replacing(with: filterTensor, where: finalTokensToRemove)
}
}
#endif