diff --git a/asm2vec/__init__.py b/asm2vec/__init__.py index ae7efea..d3afa2c 100644 --- a/asm2vec/__init__.py +++ b/asm2vec/__init__.py @@ -1 +1 @@ -__all__ = ["datatype", "model", "utils", "version"] +__all__ = ["datatype", "model", "utils", "binary_to_asm", "version"] diff --git a/asm2vec/binary_to_asm.py b/asm2vec/binary_to_asm.py new file mode 100644 index 0000000..fe30d9a --- /dev/null +++ b/asm2vec/binary_to_asm.py @@ -0,0 +1,145 @@ +import re +import os +import hashlib +import r2pipe +import logging +from pathlib import Path + + +def _sha3(asm: str) -> str: + """Produces SHA3 for each assembly function + :param asm: input assembly function + """ + return hashlib.sha3_256(asm.encode()).hexdigest() + + +def _valid_exe(filename: str) -> bool: + """Extracts magic bytes and returns the header + :param filename: name of the malware file (SHA1) + :return: Boolean of the header existing in magic bytes + """ + magics = [bytes.fromhex('cffaedfe')] + with open(filename, 'rb') as f: + header = f.read(4) + return header in magics + + +def _normalize(opcode: str) -> str: + """ Normalizes the input string + :param opcode: opcode of the binary + """ + opcode = opcode.replace(' - ', ' + ') + opcode = re.sub(r'0x[0-9a-f]+', 'CONST', opcode) + opcode = re.sub(r'\*[0-9]', '*CONST', opcode) + opcode = re.sub(r' [0-9]', ' CONST', opcode) + return opcode + + +def _fn_to_asm(pdf: dict | None, asm_minlen: int) -> str: + """Converts functions to assembly code + :param pdf: disassembly + :param asm_minlen: minimum length of assembly functions to be extracted + """ + if pdf is None: + return '' + if len(pdf['ops']) < asm_minlen: + return '' + if 'invalid' in [op['type'] for op in pdf['ops']]: + return '' + + ops = pdf['ops'] + + labels, scope = {}, [op['offset'] for op in ops] + assert (None not in scope) + for i, op in enumerate(ops): + if op.get('jump') in scope: + labels.setdefault(op.get('jump'), i) + + output = '' + for op in ops: + if labels.get(op.get('offset')) is not None: + output += f'LABEL{labels[op["offset"]]}:\n' + if labels.get(op.get('jump')) is not None: + output += f' {op["type"]} LABEL{labels[op["jump"]]}\n' + else: + output += f' {_normalize(op["opcode"])}\n' + + return output + + +def bin_to_asm(filename: Path, output_path: Path, asm_minlen: int) -> int: + """Fragments the input binary into assembly functions via r2pipe + :param filename: name of the malware file (SHA1) + :param output_path: path to the folder to store the assembly functions for each malware + :param asm_minlen: the minimum length of assembly functions to be extracted + :return: the number of assembly functions + """ + if not _valid_exe(filename): + logging.info('The input file is invalid.') + return 0 + + r = r2pipe.open(str(filename)) + r.cmd('aaaa') + + count = 0 + + for fn in r.cmdj('aflj'): + r.cmd(f's {fn["offset"]}') + asm = _fn_to_asm(r.cmdj('pdfj'), asm_minlen) + if asm: + uid = _sha3(asm) + asm = f''' .name {fn["name"]}\ + .offset {fn["offset"]:016x}\ + .file {filename.name}''' + asm + output_asm = os.path.join(output_path, uid) + with open(output_asm, 'w') as file: + file.write(asm) + count += 1 + return count + + +def convert_to_asm(input_path, output_path, minlen_upper: int, minlen_lower: int) -> list: + """ Extracts assembly functions from malware files and saves them + into separate folder per binary + :param input_path: the path to the malware binaries + :param output_path: the path for the assembly functions to be extracted + :param minlen_upper: The minimum number of assembly functions needed for disassembling + :param minlen_lower: If disassembling not possible with with minlen_upper, lower the minimum number + of assembly functions to minlen_lower + :return: List of sha1 of disassembled malware files + """ + + binary_dir = Path(input_path) + asm_dir = Path(output_path) + + if not os.path.exists(asm_dir): + os.mkdir(asm_dir) + + function_count, binary_count, not_found = 0, 0, 0 + disassembled_bins = [] + + if os.path.isdir(binary_dir): + for entry in os.scandir(binary_dir): + out_dir = os.path.join(asm_dir, entry.name) + if not (os.path.exists(out_dir)): + os.mkdir(out_dir) + function_count += bin_to_asm(Path(entry), Path(out_dir), minlen_upper) + if function_count == 0: + function_count += bin_to_asm(Path(entry), Path(out_dir), minlen_lower) + if function_count == 0: + os.rmdir(out_dir) + logging.info('The binary {} was not disassembled'.format(entry.name)) + else: + binary_count += 1 + disassembled_bins.append(entry.name) + else: + binary_count += 1 + disassembled_bins.append(entry.name) + else: + not_found += 1 + logging.info("[Error] No such file or directory: {}".format(binary_dir)) + + logging.info("Total scanned binaries: {}".format(binary_count)) + logging.info("Not converted binaries: {}".format(not_found)) + + return disassembled_bins diff --git a/asm2vec/datatype.py b/asm2vec/datatype.py index a3cd39b..b6451d8 100644 --- a/asm2vec/datatype.py +++ b/asm2vec/datatype.py @@ -2,19 +2,23 @@ import random import warnings + class Token: def __init__(self, name, index): self.name = name self.index = index self.count = 1 + def __str__(self): return self.name + class Tokens: def __init__(self, name_to_index=None, tokens=None): self.name_to_index = name_to_index or {} self.tokens = tokens or [] self._weights = None + def __getitem__(self, key): if type(key) is str: if self.name_to_index.get(key) is None: @@ -28,13 +32,17 @@ def __getitem__(self, key): return [self[k] for k in key] except: raise ValueError + def load_state_dict(self, sd): self.name_to_index = sd['name_to_index'] self.tokens = sd['tokens'] + def state_dict(self): return {'name_to_index': self.name_to_index, 'tokens': self.tokens} + def size(self): return len(self.tokens) + def add(self, names): self._weights = None if type(names) is not list: @@ -46,6 +54,7 @@ def add(self, names): self.tokens.append(token) else: self.tokens[self.name_to_index[name]].count += 1 + def update(self, tokens_new): for token in tokens_new: if token.name not in self.name_to_index: @@ -54,6 +63,7 @@ def update(self, tokens_new): self.tokens.append(token) else: self.tokens[self.name_to_index[token.name]].count += token.count + def weights(self): # if no cache, calculate if self._weights is None: @@ -62,19 +72,22 @@ def weights(self): for token in self.tokens: self._weights[token.index] = (token.count / total) ** 0.75 return self._weights + def sample(self, batch_size, num=5): return torch.multinomial(self.weights(), num * batch_size, replacement=True).view(batch_size, num) + class Function: def __init__(self, insts, blocks, meta): self.insts = insts self.blocks = blocks self.meta = meta + @classmethod def load(cls, text): - ''' - gcc -S format compatiable - ''' + """gcc -S format compatible + """ + label, labels, insts, blocks, meta = None, {}, [], [], {} for line in text.strip('\n').split('\n'): if line[0] in [' ', '\t']: @@ -109,10 +122,13 @@ def load(cls, text): if labels.get(arg): inst.args[i] = 'CONST' return cls(insts, blocks, meta) + def tokens(self): return [token for inst in self.insts for token in inst.tokens()] + def random_walk(self, num=3): return [self._random_walk() for _ in range(num)] + def _random_walk(self): current, visited, seq = self.blocks[0], [], [] while current not in visited: @@ -124,25 +140,31 @@ def _random_walk(self): current = random.choice(list(current.successors)) return seq + class BasicBlock: def __init__(self): self.insts = [] self.successors = set() + def add(self, inst): self.insts.append(inst) + def end(self): inst = self.insts[-1] return inst.is_jmp() or inst.op == 'ret' + class Instruction: def __init__(self, op, args): self.op = op self.args = args + def __str__(self): return f'{self.op} {", ".join([str(arg) for arg in self.args if str(arg)])}' + @classmethod def load(cls, text): - text = text.strip().strip('bnd').strip() # get rid of BND prefix + text = text.strip().strip('bnd').strip() op, _, args = text.strip().partition(' ') if args: args = [arg.strip() for arg in args.split(',')] @@ -150,9 +172,12 @@ def load(cls, text): args = [] args = (args + ['', ''])[:2] return cls(op, args) + def tokens(self): return [self.op] + self.args + def is_jmp(self): return 'jmp' in self.op or self.op[0] == 'j' + def is_call(self): return self.op == 'call' diff --git a/asm2vec/model.py b/asm2vec/model.py index 301f3be..74a6ace 100644 --- a/asm2vec/model.py +++ b/asm2vec/model.py @@ -3,35 +3,43 @@ bce, sigmoid, softmax = nn.BCELoss(), nn.Sigmoid(), nn.Softmax(dim=1) + class ASM2VEC(nn.Module): def __init__(self, vocab_size, function_size, embedding_size): super(ASM2VEC, self).__init__() - self.embeddings = nn.Embedding(vocab_size, embedding_size, _weight=torch.zeros(vocab_size, embedding_size)) - self.embeddings_f = nn.Embedding(function_size, 2 * embedding_size, _weight=(torch.rand(function_size, 2 * embedding_size)-0.5)/embedding_size/2) - self.embeddings_r = nn.Embedding(vocab_size, 2 * embedding_size, _weight=(torch.rand(vocab_size, 2 * embedding_size)-0.5)/embedding_size/2) + self.embeddings = nn.Embedding(vocab_size, embedding_size, _weight=torch.zeros(vocab_size, embedding_size)) + self.embeddings_f = nn.Embedding(function_size, 2 * embedding_size, + _weight=(torch.rand(function_size, 2 * embedding_size)-0.5)/embedding_size/2) + self.embeddings_r = nn.Embedding(vocab_size, 2 * embedding_size, + _weight=(torch.rand(vocab_size, 2 * embedding_size)-0.5)/embedding_size/2) def update(self, function_size_new, vocab_size_new): device = self.embeddings.weight.device - vocab_size, function_size, embedding_size = self.embeddings.num_embeddings, self.embeddings_f.num_embeddings, self.embeddings.embedding_dim + vocab_size, function_size, embedding_size = (self.embeddings.num_embeddings, + self.embeddings_f.num_embeddings, self.embeddings.embedding_dim) if vocab_size_new != vocab_size: - weight = torch.cat([self.embeddings.weight, torch.zeros(vocab_size_new - vocab_size, embedding_size).to(device)]) + weight = torch.cat([self.embeddings.weight, torch.zeros(vocab_size_new - vocab_size, embedding_size). + to(device)]) self.embeddings = nn.Embedding(vocab_size_new, embedding_size, _weight=weight) - weight_r = torch.cat([self.embeddings_r.weight, ((torch.rand(vocab_size_new - vocab_size, 2 * embedding_size)-0.5)/embedding_size/2).to(device)]) + weight_r = torch.cat([self.embeddings_r.weight, + ((torch.rand(vocab_size_new - vocab_size, 2 * embedding_size)-0.5)/embedding_size/2) + .to(device)]) self.embeddings_r = nn.Embedding(vocab_size_new, 2 * embedding_size, _weight=weight_r) - self.embeddings_f = nn.Embedding(function_size_new, 2 * embedding_size, _weight=((torch.rand(function_size_new, 2 * embedding_size)-0.5)/embedding_size/2).to(device)) + self.embeddings_f = nn.Embedding(function_size_new, 2 * embedding_size, + _weight=((torch.rand(function_size_new, 2 * embedding_size)-0.5) / + embedding_size/2).to(device)) def v(self, inp): - e = self.embeddings(inp[:,1:]) - v_f = self.embeddings_f(inp[:,0]) - v_prev = torch.cat([e[:,0], (e[:,1] + e[:,2]) / 2], dim=1) - v_next = torch.cat([e[:,3], (e[:,4] + e[:,5]) / 2], dim=1) + e = self.embeddings(inp[:, 1:]) + v_f = self.embeddings_f(inp[:, 0]) + v_prev = torch.cat([e[:, 0], (e[:, 1] + e[:, 2]) / 2], dim=1) + v_next = torch.cat([e[:, 3], (e[:, 4] + e[:, 5]) / 2], dim=1) v = ((v_f + v_prev + v_next) / 3).unsqueeze(2) return v def forward(self, inp, pos, neg): device, batch_size = inp.device, inp.shape[0] v = self.v(inp) - # negative sampling loss pred = torch.bmm(self.embeddings_r(torch.cat([pos, neg], dim=1)), v).squeeze() label = torch.cat([torch.ones(batch_size, 3), torch.zeros(batch_size, neg.shape[1])], dim=1).to(device) return bce(sigmoid(pred), label) @@ -39,5 +47,6 @@ def forward(self, inp, pos, neg): def predict(self, inp, pos): device, batch_size = inp.device, inp.shape[0] v = self.v(inp) - probs = torch.bmm(self.embeddings_r(torch.arange(self.embeddings_r.num_embeddings).repeat(batch_size, 1).to(device)), v).squeeze(dim=2) + probs = torch.bmm(self.embeddings_r(torch.arange(self.embeddings_r.num_embeddings).repeat(batch_size, 1). + to(device)), v).squeeze(dim=2) return softmax(probs) diff --git a/asm2vec/utils.py b/asm2vec/utils.py index 4f9aa25..b233d33 100644 --- a/asm2vec/utils.py +++ b/asm2vec/utils.py @@ -3,18 +3,22 @@ import torch from torch.utils.data import DataLoader, Dataset from pathlib import Path -from .datatype import Tokens, Function, Instruction -from .model import ASM2VEC +from asm2vec.datatype import Tokens, Function, Instruction +from asm2vec.model import ASM2VEC + class AsmDataset(Dataset): def __init__(self, x, y): self.x = x self.y = y + def __len__(self): return len(self.x) + def __getitem__(self, index): return self.x[index], self.y[index] + def load_data(paths, limit=None): if type(paths) is not list: paths = [paths] @@ -22,7 +26,8 @@ def load_data(paths, limit=None): filenames = [] for path in paths: if os.path.isdir(path): - filenames += [Path(path) / filename for filename in sorted(os.listdir(path)) if os.path.isfile(Path(path) / filename)] + filenames += [Path(path) / filename for filename in sorted(os.listdir(path)) + if os.path.isfile(Path(path) / filename)] else: filenames += [Path(path)] @@ -37,6 +42,7 @@ def load_data(paths, limit=None): return functions, tokens + def preprocess(functions, tokens): x, y = [], [] for i, fn in enumerate(functions): @@ -46,6 +52,7 @@ def preprocess(functions, tokens): y.append([tokens[token].index for token in seq[j].tokens()]) return torch.tensor(x), torch.tensor(y) + def train( functions, tokens, @@ -102,6 +109,7 @@ def train( return model + def save_model(path, model, tokens): torch.save({ 'model_params': ( @@ -113,6 +121,7 @@ def save_model(path, model, tokens): 'tokens': tokens.state_dict(), }, path) + def load_model(path, device='cpu'): checkpoint = torch.load(path, map_location=device) tokens = Tokens() @@ -122,35 +131,37 @@ def load_model(path, device='cpu'): model = model.to(device) return model, tokens + def show_probs(x, y, probs, tokens, limit=None, pretty=False): if pretty: - TL, TR, BL, BR = '┌', '┐', '└', '┘' - LM, RM, TM, BM = '├', '┤', '┬', '┴' - H, V = '─', '│' + tl, tr, bl, br = '┌', '┐', '└', '┘' + lm, rm, tm, bm = '├', '┤', '┬', '┴' + h, v = '─', '│' arrow = ' ➔' else: - TL = TR = BL = BR = '+' - LM = RM = TM = BM = '+' - H, V = '-', '|' + tl, tr, bl, br = '+', '+', '+', '+' + lm, rm, tm, bm = '+', '+', '+', '+' + h, v = '-', '|' arrow = '->' top = probs.topk(5) for i, (xi, yi) in enumerate(zip(x, y)): if limit and i >= limit: break xi, yi = xi.tolist(), yi.tolist() - print(TL + H * 42 + TR) - print(f'{V} {str(Instruction(tokens[xi[1]], tokens[xi[2:4]])):37} {V}') - print(f'{V} {arrow} {str(Instruction(tokens[yi[0]], tokens[yi[1:3]])):37} {V}') - print(f'{V} {str(Instruction(tokens[xi[4]], tokens[xi[5:7]])):37} {V}') - print(LM + H * 8 + TM + H * 33 + RM) + print(tl + h * 42 + tr) + print(f'{v} {str(Instruction(tokens[xi[1]], tokens[xi[2:4]])):37} {v}') + print(f'{v} {arrow} {str(Instruction(tokens[yi[0]], tokens[yi[1:3]])):37} {v}') + print(f'{v} {str(Instruction(tokens[xi[4]], tokens[xi[5:7]])):37} {v}') + print(lm + h * 8 + tm + h * 33 + rm) for value, index in zip(top.values[i], top.indices[i]): if index in yi: colorbegin, colorclear = '\033[92m', '\033[0m' else: colorbegin, colorclear = '', '' - print(f'{V} {colorbegin}{value*100:05.2f}%{colorclear} {V} {colorbegin}{tokens[index.item()].name:31}{colorclear} {V}') - print(BL + H * 8 + BM + H * 33 + BR) + print(f'{v} {colorbegin}{value*100:05.2f}%{colorclear} {v} {colorbegin}' + f'{tokens[index.item()].name:31}{colorclear} {v}') + print(bl + h * 8 + bm + h * 33 + br) + def accuracy(y, probs): return torch.mean(torch.tensor([torch.sum(probs[i][yi]) for i, yi in enumerate(y)])) - diff --git a/scripts/bin2asm.py b/scripts/bin2asm.py deleted file mode 100644 index 2134e8c..0000000 --- a/scripts/bin2asm.py +++ /dev/null @@ -1,117 +0,0 @@ -#!/usr/bin/env python3 -import re -import os -import click -import r2pipe -import hashlib -from pathlib import Path - -def sha3(data): - return hashlib.sha3_256(data.encode()).hexdigest() - -def validEXE(filename): - magics = [bytes.fromhex('7f454c46')] - with open(filename, 'rb') as f: - header = f.read(4) - return header in magics - -def normalize(opcode): - opcode = opcode.replace(' - ', ' + ') - opcode = re.sub(r'0x[0-9a-f]+', 'CONST', opcode) - opcode = re.sub(r'\*[0-9]', '*CONST', opcode) - opcode = re.sub(r' [0-9]', ' CONST', opcode) - return opcode - -def fn2asm(pdf, minlen): - # check - if pdf is None: - return - if len(pdf['ops']) < minlen: - return - if 'invalid' in [op['type'] for op in pdf['ops']]: - return - - ops = pdf['ops'] - - # set label - labels, scope = {}, [op['offset'] for op in ops] - assert(None not in scope) - for i, op in enumerate(ops): - if op.get('jump') in scope: - labels.setdefault(op.get('jump'), i) - - # dump output - output = '' - for op in ops: - # add label - if labels.get(op.get('offset')) is not None: - output += f'LABEL{labels[op["offset"]]}:\n' - # add instruction - if labels.get(op.get('jump')) is not None: - output += f' {op["type"]} LABEL{labels[op["jump"]]}\n' - else: - output += f' {normalize(op["opcode"])}\n' - - return output - -def bin2asm(filename, opath, minlen): - # check - if not validEXE(filename): - return 0 - - r = r2pipe.open(str(filename)) - r.cmd('aaaa') - - count = 0 - - for fn in r.cmdj('aflj'): - r.cmd(f's {fn["offset"]}') - asm = fn2asm(r.cmdj('pdfj'), minlen) - if asm: - uid = sha3(asm) - asm = f''' .name {fn["name"]} - .offset {fn["offset"]:016x} - .file {filename.name} -''' + asm - with open(opath / uid, 'w') as f: - f.write(asm) - count += 1 - - print(f'[+] {filename}') - - return count - -@click.command() -@click.option('-i', '--input', 'ipath', help='input directory / file', required=True) -@click.option('-o', '--output', 'opath', default='asm', help='output directory') -@click.option('-l', '--len', 'minlen', default=10, help='ignore assembly code with instructions amount smaller than minlen') -def cli(ipath, opath, minlen): - ''' - Extract assembly functions from binary executable - ''' - ipath = Path(ipath) - opath = Path(opath) - - # create output directory - if not os.path.exists(opath): - os.mkdir(opath) - - fcount, bcount = 0, 0 - - # directory - if os.path.isdir(ipath): - for f in os.listdir(ipath): - if not os.path.islink(ipath / f) and not os.path.isdir(ipath / f): - fcount += bin2asm(ipath / f, opath, minlen) - bcount += 1 - # file - elif os.path.exists(ipath): - fcount += bin2asm(ipath, opath, minlen) - bcount += 1 - else: - print(f'[Error] No such file or directory: {ipath}') - - print(f'[+] Total scan binary: {bcount} => Total generated assembly functions: {fcount}') - -if __name__ == '__main__': - cli()