34 changes: 23 additions & 11 deletions infer.py
@@ -41,15 +41,21 @@ def setup_parser():
parser.add_argument("--gen-length", type=int, default=32, help="Generation length")
parser.add_argument("--seed", type=int, default=1234, help="Seed")
parser.add_argument("--use-optimized-code", action='store_true', default=False)
parser.add_argument("--warmup-iters", type=int, default=5)
parser.add_argument("--total-iters", type=int, default=10)

return parser

def load_prompts(tokenizer, batch_size, prompt_length):
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
encodings = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt")
total_tokens = encodings.input_ids.shape[1]
start_index = min(random.randint(0, total_tokens), total_tokens - batch_size * prompt_length)
input_ids = encodings.input_ids[:, start_index : start_index + batch_size * prompt_length].reshape(batch_size, prompt_length)
input_ids = []
for _ in range(batch_size):
start_index = min(random.randint(0, total_tokens), total_tokens - prompt_length)
tokens = encodings.input_ids[:, start_index : start_index + prompt_length].reshape(1, prompt_length)
input_ids.append(tokens)
input_ids = torch.cat(input_ids, dim=0)
return input_ids

if __name__ == "__main__":
@@ -62,8 +68,8 @@ def load_prompts(tokenizer, batch_size, prompt_length):
set_seed(args.seed)

if args.method == "pca-topk":
args.top_k = args.prompt_length
args.top_r = 128
args.top_k = int(0.25 * args.prompt_length)
args.top_r = 32
args.rotary_type = "postrotary"

if args.use_optimized_code:
@@ -84,20 +90,26 @@ def load_prompts(tokenizer, batch_size, prompt_length):

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
generations = []

input_ids = tokenized_prompts.cuda()
with torch.autocast(device_type='cuda', dtype=dtype):
outputs = model.generate(input_ids, do_sample=True, max_new_tokens=args.gen_length, num_beams=4)

# warmup iters
for _ in range(args.warmup_iters):
with torch.autocast(device_type='cuda', dtype=dtype):
outputs = model.generate(input_ids, do_sample=True, max_new_tokens=args.gen_length)

# timed iters
start_event.record()
for _ in range(args.total_iters - args.warmup_iters):
with torch.autocast(device_type='cuda', dtype=dtype):
outputs = model.generate(input_ids, do_sample=True, max_new_tokens=args.gen_length)
end_event.record()


generated_tokens = outputs.numel() - input_ids.numel()
total_generated_tokens += generated_tokens

torch.cuda.synchronize()
total_time = start_event.elapsed_time(end_event)
total_time = start_event.elapsed_time(end_event) / (args.total_iters - args.warmup_iters)
tput = total_generated_tokens * 1000 / total_time

output_ids = outputs[:, args.prompt_length:]
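A condensed sketch of the warmup-plus-timed-iterations pattern this change adopts in infer.py. Names and defaults below are illustrative (`model` is assumed to expose a Hugging Face-style `generate`); the key points are that warmup iterations are excluded from the CUDA-event window and that `elapsed_time` returns milliseconds, hence the factor of 1000 in the throughput formula.

import torch

def time_generation(model, input_ids, gen_length, warmup_iters=5, total_iters=10, dtype=torch.float16):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Warmup iterations: excluded from timing (kernel compilation, allocator warmup, etc.).
    for _ in range(warmup_iters):
        with torch.autocast(device_type='cuda', dtype=dtype):
            outputs = model.generate(input_ids, do_sample=True, max_new_tokens=gen_length)

    # Timed iterations, bracketed by CUDA events.
    start_event.record()
    for _ in range(total_iters - warmup_iters):
        with torch.autocast(device_type='cuda', dtype=dtype):
            outputs = model.generate(input_ids, do_sample=True, max_new_tokens=gen_length)
    end_event.record()
    torch.cuda.synchronize()

    # elapsed_time() is in milliseconds; average over the timed iterations.
    ms_per_iter = start_event.elapsed_time(end_event) / (total_iters - warmup_iters)
    generated_tokens = outputs.numel() - input_ids.numel()
    return generated_tokens * 1000 / ms_per_iter  # tokens per second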
254 changes: 171 additions & 83 deletions methods/pca_topk/cache_utils.py
@@ -5,8 +5,8 @@
import time

import torch
import external.gather_matmul as G

#import external.gather_matmul as G
import kernel.pca_topk as G


topk_time = 0
@@ -16,6 +16,8 @@
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = False

from timers import Timers
import json

# Work In Progress
class PcaTopKCache(Cache): # Not used anymore
@@ -120,6 +122,10 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -
return max_length - new_seq_length
return previous_seq_length

def reset(self):
self.key_cache: List[torch.Tensor] = [] # Stores the reduced keys for each layer
self.value_cache: List[torch.Tensor] = []

def test_pcatopk_cache():
cache = PcaTopKCache(2, 4)

@@ -139,65 +145,98 @@ def test_pcatopk_cache():



def micro_benchmark_pca_topk(cache, prompt_keys, top_r, top_k, num_gen_steps=2000, use_optimised_gather=False):
def micro_benchmark_pca_topk(cache, prompt_keys, top_r, top_k, num_layers, timers,
num_gen_steps=2000, use_optimised_gather=False):
import time
torch.set_float32_matmul_precision("highest")

head_dim = prompt_keys.shape[-1]
bs = prompt_keys.shape[0]
num_heads = prompt_keys.shape[1]

generative_query = torch.rand(bs, num_heads, 1, head_dim).to("cuda")
generative_key = torch.rand(bs, num_heads, 1, head_dim).to("cuda")
head_dim = prompt_keys[0].shape[-1]
bs = prompt_keys[0].shape[0]
num_heads = prompt_keys[0].shape[1]
dtype = prompt_keys[0].dtype
prompt_seq_length = prompt_keys[0].shape[2]

print ("Starting microbenchmark")
matmul_time = 0
top_keys = torch.zeros(bs, num_heads, top_k, head_dim).to("cuda")
top_vals = torch.zeros(bs, num_heads, top_k, head_dim).to("cuda")
pca_projection_mat = torch.randn(num_heads, head_dim, head_dim, dtype=dtype, device='cuda')


assert use_optimised_gather
if use_optimised_gather:
timers.start('total')
for i in range(num_gen_steps):
keys, vals = cache.update(generative_key, generative_key, generative_query, 0, False)
torch.cuda.synchronize()

start = time.time()
attn_weights = torch.matmul(generative_query[:,:,:,:top_r], keys.transpose(2, 3)[:,:,:top_r,:]) / math.sqrt(head_dim)

# Get top-k keys and top-k values based on the attention scores
key_states_topk_indices = torch.topk(attn_weights, top_k, dim=-1).indices.to("cuda")
key_states_topk_indices,_ = torch.sort(key_states_topk_indices, dim=-1)
key_states_topk_indices= key_states_topk_indices.reshape(-1, key_states_topk_indices.shape[-1])

keys = keys.reshape(-1, keys.shape[-2] , keys.shape[-1])
vals = vals.reshape(-1, vals.shape[-2] , vals.shape[-1])

attn_weights = G.gather_outer_bmv(
generative_query.reshape(-1, 1, head_dim),
keys.transpose(-1, -2),
key_states_topk_indices,
#.squeeze(0).squeeze(-1),
chunk=256
#chunk=min(k2, 65536 // Q.shape[-1]),
) / math.sqrt(head_dim)
attn_weights = torch.softmax(attn_weights, dim=-1)

attn_output = G.gather_inner_matrix_only_bmv(
attn_weights, vals, key_states_topk_indices, chunk=64
)

torch.cuda.synchronize()
end = time.time()

if i > 5:
matmul_time += end - start
for layer in range(num_layers):
timers.start('qk-gen')
generative_query = torch.rand(bs, num_heads, 1, head_dim, device='cuda', dtype=dtype)
generative_key = torch.rand(bs, num_heads, 1, head_dim, device='cuda', dtype=dtype)
timers.stop('qk-gen')

timers.start('project')
generative_key = generative_key.squeeze().transpose(0, 1).bmm(pca_projection_mat).unsqueeze(2)
generative_query = generative_query.squeeze().transpose(0, 1).bmm(pca_projection_mat).unsqueeze(2)
timers.stop('project')

timers.start('cache-update')
keys, vals = cache.update(generative_key, generative_key, generative_query, layer, False)
timers.stop('cache-update')

timers.start('qk-matmul-1')
#attn_weights = torch.matmul(generative_query[:,:,:,:top_r], keys.transpose(2, 3)[:,:,:top_r,:]) / math.sqrt(head_dim)
nh, bs, s, r = keys.shape
attn_weights = G.topr_bmv_optimized(A=generative_query.view(nh*bs, 1, r), B=keys.view(nh*bs, s, r).transpose(-1,-2),
r=top_r)
attn_weights = attn_weights.view(nh, bs, 1, s)
timers.stop('qk-matmul-1')

# Get top-k keys and top-k values based on the attention scores
timers.start('top-k')
#key_states_topk_indices = torch.topk(attn_weights, top_k, dim=-1, sorted=False).indices.to("cuda")
#key_states_topk_indices,_ = torch.sort(key_states_topk_indices, dim=-1)
key_states_topk_indices = torch.argsort(attn_weights, dim=-1, descending=True)[:,:,:,:top_k]
timers.stop('top-k')


timers.start('reshape-0')
key_states_topk_indices= key_states_topk_indices.reshape(-1, key_states_topk_indices.shape[-1])
timers.stop('reshape-0')

timers.start('reshape-1')
keys = keys.view(-1, keys.shape[-2] , keys.shape[-1])
vals = vals.view(-1, vals.shape[-2] , vals.shape[-1])
timers.stop('reshape-1')

timers.start('qk-matmul-2')
attn_weights = G.gather_outer_bmv_optimized(
generative_query.reshape(-1, 1, head_dim),
keys.transpose(-1, -2),
key_states_topk_indices,
#.squeeze(0).squeeze(-1),
#chunk=256
#chunk=min(k2, 65536 // Q.shape[-1]),
) / math.sqrt(head_dim)
timers.stop('qk-matmul-2')

timers.start('softmax')
attn_weights = torch.softmax(attn_weights.float(), dim=-1).to(dtype)
timers.stop('softmax')

timers.start('sv-matmul')
attn_output = G.gather_inner_matrix_only_bmv_optimized(
attn_weights, vals, key_states_topk_indices)
timers.stop('sv-matmul')

timers.start('reshape-output')
attn_output = attn_output.view(num_heads, bs, 1, head_dim).transpose(0,1).transpose(1,2).contiguous()
timers.stop('reshape-output')
timers.stop('total')
else:
for i in range(num_gen_steps):
keys, vals = cache.update(generative_key, generative_key, generative_query, 0, False)
torch.cuda.synchronize()

start = time.time()
attn_weights = torch.matmul(generative_query[:,:,:,:top_r], keys.transpose(2, 3)[:,:,:top_r,:]) / math.sqrt(head_dim)

# Get top-k keys and top-k values based on the attention scores
key_states_topk_indices = torch.topk(attn_weights, top_k, dim=-1).indices.to("cuda")
key_states_topk_indices,_ = torch.sort(key_states_topk_indices, dim=-1)
@@ -212,70 +251,119 @@ def micro_benchmark_pca_topk(cache, prompt_keys, top_r, top_k, num_gen_steps=200
torch.cuda.synchronize()
end = time.time()

if i > 5:
matmul_time += end - start
print (f"Matmul Time: {matmul_time}")
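For clarity, the selection pipeline that the optimized branch above runs with the fused kernel.pca_topk routines can be written in plain PyTorch roughly as follows. This is a sketch, not the kernels themselves: queries and keys are assumed to already be projected into the PCA basis, shapes follow the (bs, num_heads, seq, head_dim) layout, and the rescoring over the gathered keys only mirrors what gather_outer_bmv_optimized / gather_inner_matrix_only_bmv_optimized compute.

import math
import torch

def pca_topk_attention_reference(query, keys, vals, top_r, top_k):
    # query: (bs, nh, 1, head_dim); keys, vals: (bs, nh, seq, head_dim)
    head_dim = query.shape[-1]

    # 1. Approximate scores from the first top_r PCA components of Q and K.
    approx = torch.matmul(query[..., :top_r], keys.transpose(2, 3)[..., :top_r, :]) / math.sqrt(head_dim)

    # 2. Keep the top_k key positions per head according to the approximate scores.
    topk_idx = torch.topk(approx, top_k, dim=-1).indices            # (bs, nh, 1, top_k)

    # 3. Gather the selected keys/values, rescore with full head_dim, softmax, mix values.
    gather_idx = topk_idx.squeeze(2).unsqueeze(-1).expand(-1, -1, -1, head_dim)
    top_keys = torch.gather(keys, 2, gather_idx)                    # (bs, nh, top_k, head_dim)
    top_vals = torch.gather(vals, 2, gather_idx)
    attn = torch.matmul(query, top_keys.transpose(2, 3)) / math.sqrt(head_dim)
    attn = torch.softmax(attn.float(), dim=-1).to(query.dtype)
    return torch.matmul(attn, top_vals)                             # (bs, nh, 1, head_dim)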

def micro_bench_actual_attention(cache, prompt_keys, num_gen_steps=2000):
def micro_bench_actual_attention(cache, prompt_keys, num_layers, timers, num_gen_steps=2000):
import time
torch.set_float32_matmul_precision("highest")

head_dim = prompt_keys.shape[-1]
bs = prompt_keys.shape[0]
num_heads = prompt_keys.shape[1]

generative_query = torch.rand(bs, num_heads, 1, head_dim).to("cuda")
generative_key = torch.rand(bs, num_heads, 1, head_dim).to("cuda")
head_dim = prompt_keys[0].shape[-1]
bs = prompt_keys[0].shape[0]
num_heads = prompt_keys[0].shape[1]
dtype = prompt_keys[0].dtype

print ("Starting microbenchmark")
matmul_time = 0
for i in range(num_gen_steps):
keys, vals = cache.update(generative_key, generative_key, generative_query, 0, False)
torch.cuda.synchronize()

start = time.time()
attn_weights = torch.matmul(generative_query, keys.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = torch.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, vals)
torch.cuda.synchronize()
end = time.time()
timers.start('total')
for i in range(num_gen_steps):
for layer in range(num_layers):
timers.start('qk-gen')
generative_query = torch.rand(bs, num_heads, 1, head_dim, dtype=dtype, device='cuda')
generative_key = torch.rand(bs, num_heads, 1, head_dim, dtype=dtype, device='cuda')
timers.stop('qk-gen')

timers.start('cache-update')
keys, vals = cache.update(generative_key, generative_key, generative_query, layer, False)
timers.stop('cache-update')

timers.start('qk-matmul-1')
attn_weights = torch.matmul(generative_query, keys.transpose(2, 3)) / math.sqrt(head_dim)
timers.stop('qk-matmul-1')

timers.start('softmax')
attn_weights = torch.softmax(attn_weights.float(), dim=-1).to(dtype)
timers.stop('softmax')

timers.start('sv-matmul')
attn_output = torch.matmul(attn_weights, vals)
timers.stop('sv-matmul')

timers.start('reshape-output')
attn_output = attn_output.transpose(1, 2).contiguous()
timers.stop('reshape-output')


if i > 5:
matmul_time += end - start
print (f"Matmul Time: {matmul_time}")
timers.stop('total')
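The timers module imported at the top of this file is not included in this diff. The sketch below is a minimal stand-in consistent with how Timers is used here (named start/stop pairs, nesting allowed, get_times() returning seconds per label); the real implementation may differ, for example by averaging across calls or using CUDA events.

import time
import torch

class Timers:
    def __init__(self):
        self._starts = {}
        self._totals = {}

    def start(self, name):
        torch.cuda.synchronize()  # attribute pending GPU work to the right label
        self._starts[name] = time.time()

    def stop(self, name):
        torch.cuda.synchronize()
        self._totals[name] = self._totals.get(name, 0.0) + (time.time() - self._starts.pop(name))

    def get_times(self):
        return dict(self._totals)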

@torch.no_grad()
def benchmark_attention(batch_size=1,
num_heads=32,
num_gen_steps=128,
prompt_length=3072,
topk=256):
topk=256,
num_layers=32,
dtype=torch.float16):

head_dim=128
# Change this to change batch size, etc.
prompt_keys = torch.rand(batch_size, num_heads, prompt_length, head_dim).to("cuda")
prompt_keys = [torch.rand(batch_size, num_heads, prompt_length, head_dim, device='cuda', dtype=dtype) for _ in range(num_layers)]


print("PCA TOPK Unoptimized")
cache1 = PcaTopKCache()
cache1.update(prompt_keys, prompt_keys, prompt_keys, 0)
micro_benchmark_pca_topk(cache1, prompt_keys, 32, topk, num_gen_steps=num_gen_steps)
del cache1
#print("PCA TOPK Unoptimized")
#cache1 = [PcaTopKCache() for _ in range(num_layers)]
#for i in range(num_layers):
# cache1[i].update(prompt_keys[i], prompt_keys[i], prompt_keys[i], 0)
#micro_benchmark_pca_topk(cache1, prompt_keys, 32, topk, num_gen_steps=num_gen_steps)
#del cache1


print("PCA TOPK Optimized")
cache2 = PcaTopKCache()
cache2.update(prompt_keys, prompt_keys, prompt_keys, 0)
micro_benchmark_pca_topk(cache2, prompt_keys, 32, topk, num_gen_steps=num_gen_steps, use_optimised_gather=True)
del cache2
for _ in range(10):
cache2 = PcaTopKCache()
for i in range(num_layers):
cache2.update(prompt_keys[i].transpose(0,1).contiguous(),
prompt_keys[i].transpose(0,1).contiguous(),
prompt_keys[i].transpose(0,1).contiguous(), i)
timers = Timers()
micro_benchmark_pca_topk(cache2, prompt_keys, 32, topk,
num_gen_steps=num_gen_steps, num_layers=num_layers,
use_optimised_gather=True, timers=timers)
del cache2
times = timers.get_times()
print(times)

print("Average time (minus cache updates) is - ")
print(times['total'] - times['cache-update'], " s")
print("==================================")
times_pca_topk = times

print("Actual Attention")
cache3= PcaTopKCache()
cache3.update(prompt_keys, prompt_keys, prompt_keys, 0)
micro_bench_actual_attention(cache3, prompt_keys, num_gen_steps=num_gen_steps)
del cache3
for _ in range(10):
cache3= PcaTopKCache()
for i in range(num_layers):
cache3.update(prompt_keys[i], prompt_keys[i], prompt_keys[i], i)
timers = Timers()
micro_bench_actual_attention(cache3, prompt_keys, num_layers=num_layers,
num_gen_steps=num_gen_steps, timers=timers)
del cache3
times = timers.get_times()
print("Average time (minus cache updates) is - ")
print(times['total'] - times['cache-update'], " s")
print(times)
print("==================================")
times_vanilla = times
return times_pca_topk, times_vanilla

if __name__ == "__main__":
#test_pcatopk_cache()
with torch.no_grad():
benchmark_attention(prompt_length=4096, num_gen_steps=2000, batch_size=16, topk=1024)
prompt_length = 2000
for num_gen_steps in [1000]:
print(f"prompt length = {prompt_length}, gen length = {num_gen_steps}, batch_size={16}, topk and top r are 25%")
times_pca_topk, times_vanilla = benchmark_attention(prompt_length=prompt_length, num_gen_steps=num_gen_steps, batch_size=16, topk=prompt_length // 4, num_layers=1)
with open(f"prompt_{prompt_length}_gen_{num_gen_steps}_pca_topk_opt_first_matmul.json", "w") as f:
json.dump(times_pca_topk, f, indent=2)

with open(f"prompt_{prompt_length}_gen_{num_gen_steps}_vanilla.json", "w") as f:
json.dump(times_vanilla, f, indent=2)
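A small example of consuming the timing JSONs written above (a sketch; the filenames follow the pattern used in __main__ and the keys follow the timer labels used in this file):

import json

with open("prompt_2000_gen_1000_pca_topk_opt_first_matmul.json") as f:
    pca_topk = json.load(f)
with open("prompt_2000_gen_1000_vanilla.json") as f:
    vanilla = json.load(f)

pca_topk_attn = pca_topk["total"] - pca_topk["cache-update"]
vanilla_attn = vanilla["total"] - vanilla["cache-update"]
print(f"Attention speedup (excluding cache updates): {vanilla_attn / pca_topk_attn:.2f}x")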

