2 changes: 1 addition & 1 deletion asm2vec/__init__.py
@@ -1 +1 @@
__all__ = ["datatype", "model", "utils", "version"]
__all__ = ["datatype", "model", "utils", "binary_to_asm", "version"]
145 changes: 145 additions & 0 deletions asm2vec/binary_to_asm.py
@@ -0,0 +1,145 @@
import re
import os
import hashlib
import r2pipe
import logging
from pathlib import Path


def _sha3(asm: str) -> str:
"""Produces SHA3 for each assembly function
:param asm: input assembly function
"""
return hashlib.sha3_256(asm.encode()).hexdigest()


def _valid_exe(filename: str) -> bool:
"""Extracts magic bytes and returns the header
:param filename: name of the malware file (SHA1)
:return: Boolean of the header existing in magic bytes
"""
magics = [bytes.fromhex('cffaedfe')]
with open(filename, 'rb') as f:
header = f.read(4)
return header in magics


def _normalize(opcode: str) -> str:
""" Normalizes the input string
:param opcode: opcode of the binary
"""
opcode = opcode.replace(' - ', ' + ')
opcode = re.sub(r'0x[0-9a-f]+', 'CONST', opcode)
opcode = re.sub(r'\*[0-9]', '*CONST', opcode)
opcode = re.sub(r' [0-9]', ' CONST', opcode)
return opcode


def _fn_to_asm(pdf: dict | None, asm_minlen: int) -> str:
"""Converts functions to assembly code
:param pdf: disassembly
:param asm_minlen: minimum length of assembly functions to be extracted
"""
if pdf is None:
return ''
if len(pdf['ops']) < asm_minlen:
return ''
if 'invalid' in [op['type'] for op in pdf['ops']]:
return ''

ops = pdf['ops']

labels, scope = {}, [op['offset'] for op in ops]
assert (None not in scope)
for i, op in enumerate(ops):
if op.get('jump') in scope:
labels.setdefault(op.get('jump'), i)

output = ''
for op in ops:
if labels.get(op.get('offset')) is not None:
output += f'LABEL{labels[op["offset"]]}:\n'
if labels.get(op.get('jump')) is not None:
output += f' {op["type"]} LABEL{labels[op["jump"]]}\n'
else:
output += f' {_normalize(op["opcode"])}\n'

return output


def bin_to_asm(filename: Path, output_path: Path, asm_minlen: int) -> int:
"""Fragments the input binary into assembly functions via r2pipe
:param filename: name of the malware file (SHA1)
:param output_path: path to the folder to store the assembly functions for each malware
:param asm_minlen: the minimum length of assembly functions to be extracted
:return: the number of assembly functions
"""
if not _valid_exe(filename):
logging.info('The input file is invalid.')
return 0

r = r2pipe.open(str(filename))
r.cmd('aaaa')

count = 0

for fn in r.cmdj('aflj'):
r.cmd(f's {fn["offset"]}')
asm = _fn_to_asm(r.cmdj('pdfj'), asm_minlen)
if asm:
uid = _sha3(asm)
            # prepend the metadata header (one field per line) expected by Function.load
            asm = (f' .name {fn["name"]}\n'
                   f' .offset {fn["offset"]:016x}\n'
                   f' .file {filename.name}\n') + asm
output_asm = os.path.join(output_path, uid)
with open(output_asm, 'w') as file:
file.write(asm)
count += 1
return count


def convert_to_asm(input_path, output_path, minlen_upper: int, minlen_lower: int) -> list:
""" Extracts assembly functions from malware files and saves them
into separate folder per binary
:param input_path: the path to the malware binaries
:param output_path: the path for the assembly functions to be extracted
:param minlen_upper: The minimum number of assembly functions needed for disassembling
:param minlen_lower: If disassembling not possible with with minlen_upper, lower the minimum number
of assembly functions to minlen_lower
:return: List of sha1 of disassembled malware files
"""

binary_dir = Path(input_path)
asm_dir = Path(output_path)

if not os.path.exists(asm_dir):
os.mkdir(asm_dir)

function_count, binary_count, not_found = 0, 0, 0
disassembled_bins = []

if os.path.isdir(binary_dir):
for entry in os.scandir(binary_dir):
out_dir = os.path.join(asm_dir, entry.name)
if not (os.path.exists(out_dir)):
os.mkdir(out_dir)
            # count functions per binary so the minlen_lower fallback applies to each file,
            # not only to the first one processed
            fn_count = bin_to_asm(Path(entry), Path(out_dir), minlen_upper)
            if fn_count == 0:
                fn_count = bin_to_asm(Path(entry), Path(out_dir), minlen_lower)
            if fn_count == 0:
                os.rmdir(out_dir)
                logging.info('The binary {} was not disassembled'.format(entry.name))
            else:
                function_count += fn_count
                binary_count += 1
                disassembled_bins.append(entry.name)
else:
not_found += 1
logging.info("[Error] No such file or directory: {}".format(binary_dir))

logging.info("Total scanned binaries: {}".format(binary_count))
logging.info("Not converted binaries: {}".format(not_found))

return disassembled_bins
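Not part of the diff: a minimal usage sketch of the new module, assuming it is imported from the package as shown; the folder names and thresholds below are made up for illustration.

from asm2vec.binary_to_asm import convert_to_asm

# Hypothetical folders: one with raw Mach-O binaries, one that will receive a
# sub-folder of assembly functions per successfully disassembled binary.
disassembled = convert_to_asm(
    input_path='data/binaries',
    output_path='data/asm',
    minlen_upper=10,  # prefer functions with at least 10 instructions
    minlen_lower=5,   # fall back to 5 instructions if nothing qualifies
)
print(f'{len(disassembled)} binaries disassembled')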
33 changes: 29 additions & 4 deletions asm2vec/datatype.py
@@ -2,19 +2,23 @@
import random
import warnings


class Token:
def __init__(self, name, index):
self.name = name
self.index = index
self.count = 1

def __str__(self):
return self.name


class Tokens:
def __init__(self, name_to_index=None, tokens=None):
self.name_to_index = name_to_index or {}
self.tokens = tokens or []
self._weights = None

def __getitem__(self, key):
if type(key) is str:
if self.name_to_index.get(key) is None:
@@ -28,13 +32,17 @@ def __getitem__(self, key):
return [self[k] for k in key]
except:
raise ValueError

def load_state_dict(self, sd):
self.name_to_index = sd['name_to_index']
self.tokens = sd['tokens']

def state_dict(self):
return {'name_to_index': self.name_to_index, 'tokens': self.tokens}

def size(self):
return len(self.tokens)

def add(self, names):
self._weights = None
if type(names) is not list:
@@ -46,6 +54,7 @@ def add(self, names):
self.tokens.append(token)
else:
self.tokens[self.name_to_index[name]].count += 1

def update(self, tokens_new):
for token in tokens_new:
if token.name not in self.name_to_index:
@@ -54,6 +63,7 @@ def update(self, tokens_new):
self.tokens.append(token)
else:
self.tokens[self.name_to_index[token.name]].count += token.count

def weights(self):
# if no cache, calculate
if self._weights is None:
@@ -62,19 +72,22 @@ def weights(self):
for token in self.tokens:
self._weights[token.index] = (token.count / total) ** 0.75
return self._weights

def sample(self, batch_size, num=5):
return torch.multinomial(self.weights(), num * batch_size, replacement=True).view(batch_size, num)


class Function:
def __init__(self, insts, blocks, meta):
self.insts = insts
self.blocks = blocks
self.meta = meta

@classmethod
def load(cls, text):
'''
gcc -S format compatiable
'''
"""gcc -S format compatible
"""

label, labels, insts, blocks, meta = None, {}, [], [], {}
for line in text.strip('\n').split('\n'):
if line[0] in [' ', '\t']:
@@ -109,10 +122,13 @@ def load(cls, text):
if labels.get(arg):
inst.args[i] = 'CONST'
return cls(insts, blocks, meta)

def tokens(self):
return [token for inst in self.insts for token in inst.tokens()]

def random_walk(self, num=3):
return [self._random_walk() for _ in range(num)]

def _random_walk(self):
current, visited, seq = self.blocks[0], [], []
while current not in visited:
@@ -124,35 +140,44 @@ def _random_walk(self):
current = random.choice(list(current.successors))
return seq


class BasicBlock:
def __init__(self):
self.insts = []
self.successors = set()

def add(self, inst):
self.insts.append(inst)

def end(self):
inst = self.insts[-1]
return inst.is_jmp() or inst.op == 'ret'


class Instruction:
def __init__(self, op, args):
self.op = op
self.args = args

def __str__(self):
return f'{self.op} {", ".join([str(arg) for arg in self.args if str(arg)])}'

@classmethod
def load(cls, text):
text = text.strip().strip('bnd').strip() # get rid of BND prefix
text = text.strip().strip('bnd').strip()
op, _, args = text.strip().partition(' ')
if args:
args = [arg.strip() for arg in args.split(',')]
else:
args = []
args = (args + ['', ''])[:2]
return cls(op, args)

def tokens(self):
return [self.op] + self.args

def is_jmp(self):
return 'jmp' in self.op or self.op[0] == 'j'

def is_call(self):
return self.op == 'call'
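Not part of the diff: a small sketch exercising the Instruction and Tokens classes touched above; the sample instruction text is made up.

from asm2vec.datatype import Instruction, Tokens

inst = Instruction.load('        lea rax, [rbp - 0x8]')
print(inst.op, inst.args)   # lea ['rax', '[rbp - 0x8]']
print(inst.tokens())        # ['lea', 'rax', '[rbp - 0x8]']

vocab = Tokens()
vocab.add(inst.tokens())                  # build a tiny vocabulary from the tokens
print(vocab.size())                       # 3
print(vocab.sample(batch_size=2).shape)   # torch.Size([2, 5]) negative-sample indices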
35 changes: 22 additions & 13 deletions asm2vec/model.py
@@ -3,41 +3,50 @@

bce, sigmoid, softmax = nn.BCELoss(), nn.Sigmoid(), nn.Softmax(dim=1)


class ASM2VEC(nn.Module):
def __init__(self, vocab_size, function_size, embedding_size):
super(ASM2VEC, self).__init__()
self.embeddings = nn.Embedding(vocab_size, embedding_size, _weight=torch.zeros(vocab_size, embedding_size))
self.embeddings_f = nn.Embedding(function_size, 2 * embedding_size, _weight=(torch.rand(function_size, 2 * embedding_size)-0.5)/embedding_size/2)
self.embeddings_r = nn.Embedding(vocab_size, 2 * embedding_size, _weight=(torch.rand(vocab_size, 2 * embedding_size)-0.5)/embedding_size/2)
self.embeddings = nn.Embedding(vocab_size, embedding_size, _weight=torch.zeros(vocab_size, embedding_size))
self.embeddings_f = nn.Embedding(function_size, 2 * embedding_size,
_weight=(torch.rand(function_size, 2 * embedding_size)-0.5)/embedding_size/2)
self.embeddings_r = nn.Embedding(vocab_size, 2 * embedding_size,
_weight=(torch.rand(vocab_size, 2 * embedding_size)-0.5)/embedding_size/2)

def update(self, function_size_new, vocab_size_new):
device = self.embeddings.weight.device
vocab_size, function_size, embedding_size = self.embeddings.num_embeddings, self.embeddings_f.num_embeddings, self.embeddings.embedding_dim
vocab_size, function_size, embedding_size = (self.embeddings.num_embeddings,
self.embeddings_f.num_embeddings, self.embeddings.embedding_dim)
if vocab_size_new != vocab_size:
weight = torch.cat([self.embeddings.weight, torch.zeros(vocab_size_new - vocab_size, embedding_size).to(device)])
weight = torch.cat([self.embeddings.weight, torch.zeros(vocab_size_new - vocab_size, embedding_size).
to(device)])
self.embeddings = nn.Embedding(vocab_size_new, embedding_size, _weight=weight)
weight_r = torch.cat([self.embeddings_r.weight, ((torch.rand(vocab_size_new - vocab_size, 2 * embedding_size)-0.5)/embedding_size/2).to(device)])
weight_r = torch.cat([self.embeddings_r.weight,
((torch.rand(vocab_size_new - vocab_size, 2 * embedding_size)-0.5)/embedding_size/2)
.to(device)])
self.embeddings_r = nn.Embedding(vocab_size_new, 2 * embedding_size, _weight=weight_r)
self.embeddings_f = nn.Embedding(function_size_new, 2 * embedding_size, _weight=((torch.rand(function_size_new, 2 * embedding_size)-0.5)/embedding_size/2).to(device))
self.embeddings_f = nn.Embedding(function_size_new, 2 * embedding_size,
_weight=((torch.rand(function_size_new, 2 * embedding_size)-0.5) /
embedding_size/2).to(device))

def v(self, inp):
e = self.embeddings(inp[:,1:])
v_f = self.embeddings_f(inp[:,0])
v_prev = torch.cat([e[:,0], (e[:,1] + e[:,2]) / 2], dim=1)
v_next = torch.cat([e[:,3], (e[:,4] + e[:,5]) / 2], dim=1)
e = self.embeddings(inp[:, 1:])
v_f = self.embeddings_f(inp[:, 0])
v_prev = torch.cat([e[:, 0], (e[:, 1] + e[:, 2]) / 2], dim=1)
v_next = torch.cat([e[:, 3], (e[:, 4] + e[:, 5]) / 2], dim=1)
v = ((v_f + v_prev + v_next) / 3).unsqueeze(2)
return v

def forward(self, inp, pos, neg):
device, batch_size = inp.device, inp.shape[0]
v = self.v(inp)
# negative sampling loss
pred = torch.bmm(self.embeddings_r(torch.cat([pos, neg], dim=1)), v).squeeze()
label = torch.cat([torch.ones(batch_size, 3), torch.zeros(batch_size, neg.shape[1])], dim=1).to(device)
return bce(sigmoid(pred), label)

def predict(self, inp, pos):
device, batch_size = inp.device, inp.shape[0]
v = self.v(inp)
probs = torch.bmm(self.embeddings_r(torch.arange(self.embeddings_r.num_embeddings).repeat(batch_size, 1).to(device)), v).squeeze(dim=2)
probs = torch.bmm(self.embeddings_r(torch.arange(self.embeddings_r.num_embeddings).repeat(batch_size, 1).
to(device)), v).squeeze(dim=2)
return softmax(probs)
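Not part of the diff: a quick shape check for the reformatted model; the sizes are arbitrary, and inp packs one function index followed by six context-token indices, matching the slicing in v().

import torch
from asm2vec.model import ASM2VEC

model = ASM2VEC(vocab_size=100, function_size=10, embedding_size=50)

batch_size = 4
fn_idx = torch.randint(0, 10, (batch_size, 1))    # function index per sample
ctx = torch.randint(0, 100, (batch_size, 6))      # two context instructions: op + two args each
inp = torch.cat([fn_idx, ctx], dim=1)             # shape (4, 7)
pos = torch.randint(0, 100, (batch_size, 3))      # target instruction: op + two args
neg = torch.randint(0, 100, (batch_size, 25))     # negative samples

loss = model(inp, pos, neg)                       # negative-sampling BCE loss
loss.backward()
print(model.predict(inp, pos).shape)              # torch.Size([4, 100]) token probabilities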