This repository was archived by the owner on Mar 21, 2024. It is now read-only.

Commit 5497a20

Merge pull request #2 from wandera/AEGIS-6405-asm2vec-pytorch-edits
Aegis 6405 asm2vec pytorch edits
2 parents 20df9cc + 6632b19 commit 5497a20

File tree

6 files changed (+225 −152 lines)


asm2vec/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__all__ = ["datatype", "model", "utils", "version"]
+__all__ = ["datatype", "model", "utils", "binary_to_asm", "version"]

asm2vec/binary_to_asm.py

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@ (new file)
import re
import os
import hashlib
import r2pipe
import logging
from pathlib import Path


def _sha3(asm: str) -> str:
    """Produces the SHA3-256 hex digest of an assembly function
    :param asm: input assembly function
    """
    return hashlib.sha3_256(asm.encode()).hexdigest()


def _valid_exe(filename: str) -> bool:
    """Extracts the magic bytes and checks the header
    :param filename: name of the malware file (SHA1)
    :return: True if the header is one of the known magic byte sequences
    """
    magics = [bytes.fromhex('cffaedfe')]
    with open(filename, 'rb') as f:
        header = f.read(4)
    return header in magics


def _normalize(opcode: str) -> str:
    """Normalizes the input string
    :param opcode: opcode of the binary
    """
    opcode = opcode.replace(' - ', ' + ')
    opcode = re.sub(r'0x[0-9a-f]+', 'CONST', opcode)
    opcode = re.sub(r'\*[0-9]', '*CONST', opcode)
    opcode = re.sub(r' [0-9]', ' CONST', opcode)
    return opcode


def _fn_to_asm(pdf: dict | None, asm_minlen: int) -> str:
    """Converts a function to assembly code
    :param pdf: disassembly
    :param asm_minlen: minimum length of assembly functions to be extracted
    """
    if pdf is None:
        return ''
    if len(pdf['ops']) < asm_minlen:
        return ''
    if 'invalid' in [op['type'] for op in pdf['ops']]:
        return ''

    ops = pdf['ops']

    labels, scope = {}, [op['offset'] for op in ops]
    assert (None not in scope)
    for i, op in enumerate(ops):
        if op.get('jump') in scope:
            labels.setdefault(op.get('jump'), i)

    output = ''
    for op in ops:
        if labels.get(op.get('offset')) is not None:
            output += f'LABEL{labels[op["offset"]]}:\n'
        if labels.get(op.get('jump')) is not None:
            output += f' {op["type"]} LABEL{labels[op["jump"]]}\n'
        else:
            output += f' {_normalize(op["opcode"])}\n'

    return output


def bin_to_asm(filename: Path, output_path: Path, asm_minlen: int) -> int:
    """Fragments the input binary into assembly functions via r2pipe
    :param filename: name of the malware file (SHA1)
    :param output_path: path to the folder to store the assembly functions for each malware
    :param asm_minlen: the minimum length of assembly functions to be extracted
    :return: the number of assembly functions
    """
    if not _valid_exe(filename):
        logging.info('The input file is invalid.')
        return 0

    r = r2pipe.open(str(filename))
    r.cmd('aaaa')

    count = 0

    for fn in r.cmdj('aflj'):
        r.cmd(f's {fn["offset"]}')
        asm = _fn_to_asm(r.cmdj('pdfj'), asm_minlen)
        if asm:
            uid = _sha3(asm)
            asm = f''' .name {fn["name"]}\
 .offset {fn["offset"]:016x}\
 .file {filename.name}''' + asm
            output_asm = os.path.join(output_path, uid)
            with open(output_asm, 'w') as file:
                file.write(asm)
            count += 1
    return count


def convert_to_asm(input_path, output_path, minlen_upper: int, minlen_lower: int) -> list:
    """Extracts assembly functions from malware files and saves them
    into a separate folder per binary
    :param input_path: the path to the malware binaries
    :param output_path: the path for the assembly functions to be extracted
    :param minlen_upper: the minimum number of assembly functions needed for disassembling
    :param minlen_lower: if disassembling is not possible with minlen_upper, lower the minimum number
        of assembly functions to minlen_lower
    :return: list of SHA1s of the disassembled malware files
    """

    binary_dir = Path(input_path)
    asm_dir = Path(output_path)

    if not os.path.exists(asm_dir):
        os.mkdir(asm_dir)

    function_count, binary_count, not_found = 0, 0, 0
    disassembled_bins = []

    if os.path.isdir(binary_dir):
        for entry in os.scandir(binary_dir):
            out_dir = os.path.join(asm_dir, entry.name)
            if not (os.path.exists(out_dir)):
                os.mkdir(out_dir)
            function_count += bin_to_asm(Path(entry), Path(out_dir), minlen_upper)
            if function_count == 0:
                function_count += bin_to_asm(Path(entry), Path(out_dir), minlen_lower)
                if function_count == 0:
                    os.rmdir(out_dir)
                    logging.info('The binary {} was not disassembled'.format(entry.name))
                else:
                    binary_count += 1
                    disassembled_bins.append(entry.name)
            else:
                binary_count += 1
                disassembled_bins.append(entry.name)
    else:
        not_found += 1
        logging.info("[Error] No such file or directory: {}".format(binary_dir))

    logging.info("Total scanned binaries: {}".format(binary_count))
    logging.info("Not converted binaries: {}".format(not_found))

    return disassembled_bins
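For orientation, a minimal usage sketch of this new module; radare2 must be on the PATH for r2pipe, and the directory names and minlen thresholds below are illustrative assumptions, not values taken from this commit.

# Hypothetical usage of the new module (paths and thresholds are made-up examples).
from asm2vec.binary_to_asm import convert_to_asm

# Disassemble every binary under ./binaries into a per-binary folder of assembly
# functions under ./asm; require at least 10 functions, falling back to 5.
disassembled = convert_to_asm('binaries/', 'asm/', minlen_upper=10, minlen_lower=5)
print(f'{len(disassembled)} binaries disassembled')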

asm2vec/datatype.py

Lines changed: 29 additions & 4 deletions
@@ -2,19 +2,23 @@
 import random
 import warnings
 
+
 class Token:
     def __init__(self, name, index):
         self.name = name
         self.index = index
         self.count = 1
+
     def __str__(self):
         return self.name
 
+
 class Tokens:
     def __init__(self, name_to_index=None, tokens=None):
         self.name_to_index = name_to_index or {}
         self.tokens = tokens or []
         self._weights = None
+
     def __getitem__(self, key):
         if type(key) is str:
             if self.name_to_index.get(key) is None:
@@ -28,13 +32,17 @@ def __getitem__(self, key):
                 return [self[k] for k in key]
             except:
                 raise ValueError
+
     def load_state_dict(self, sd):
         self.name_to_index = sd['name_to_index']
         self.tokens = sd['tokens']
+
     def state_dict(self):
         return {'name_to_index': self.name_to_index, 'tokens': self.tokens}
+
     def size(self):
         return len(self.tokens)
+
     def add(self, names):
         self._weights = None
         if type(names) is not list:
@@ -46,6 +54,7 @@ def add(self, names):
                 self.tokens.append(token)
             else:
                 self.tokens[self.name_to_index[name]].count += 1
+
     def update(self, tokens_new):
         for token in tokens_new:
             if token.name not in self.name_to_index:
@@ -54,6 +63,7 @@ def update(self, tokens_new):
                 self.tokens.append(token)
             else:
                 self.tokens[self.name_to_index[token.name]].count += token.count
+
     def weights(self):
         # if no cache, calculate
         if self._weights is None:
@@ -62,19 +72,22 @@ def weights(self):
             for token in self.tokens:
                 self._weights[token.index] = (token.count / total) ** 0.75
         return self._weights
+
     def sample(self, batch_size, num=5):
         return torch.multinomial(self.weights(), num * batch_size, replacement=True).view(batch_size, num)
 
+
 class Function:
     def __init__(self, insts, blocks, meta):
         self.insts = insts
         self.blocks = blocks
         self.meta = meta
+
     @classmethod
     def load(cls, text):
-        '''
-        gcc -S format compatiable
-        '''
+        """gcc -S format compatible
+        """
+
         label, labels, insts, blocks, meta = None, {}, [], [], {}
         for line in text.strip('\n').split('\n'):
             if line[0] in [' ', '\t']:
@@ -109,10 +122,13 @@ def load(cls, text):
                 if labels.get(arg):
                     inst.args[i] = 'CONST'
         return cls(insts, blocks, meta)
+
     def tokens(self):
         return [token for inst in self.insts for token in inst.tokens()]
+
     def random_walk(self, num=3):
         return [self._random_walk() for _ in range(num)]
+
     def _random_walk(self):
         current, visited, seq = self.blocks[0], [], []
         while current not in visited:
@@ -124,35 +140,44 @@ def _random_walk(self):
             current = random.choice(list(current.successors))
         return seq
 
+
 class BasicBlock:
     def __init__(self):
         self.insts = []
         self.successors = set()
+
     def add(self, inst):
         self.insts.append(inst)
+
     def end(self):
         inst = self.insts[-1]
         return inst.is_jmp() or inst.op == 'ret'
 
+
 class Instruction:
     def __init__(self, op, args):
         self.op = op
         self.args = args
+
     def __str__(self):
         return f'{self.op} {", ".join([str(arg) for arg in self.args if str(arg)])}'
+
     @classmethod
     def load(cls, text):
-        text = text.strip().strip('bnd').strip() # get rid of BND prefix
+        text = text.strip().strip('bnd').strip()
         op, _, args = text.strip().partition(' ')
         if args:
             args = [arg.strip() for arg in args.split(',')]
         else:
             args = []
         args = (args + ['', ''])[:2]
         return cls(op, args)
+
     def tokens(self):
         return [self.op] + self.args
+
     def is_jmp(self):
         return 'jmp' in self.op or self.op[0] == 'j'
+
     def is_call(self):
         return self.op == 'call'
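As a quick, hedged sketch of the Tokens bookkeeping touched above (the token names are made-up examples, not data from this commit):

# Made-up token names for illustration; behaviour follows the methods shown in the diff.
from asm2vec.datatype import Tokens

vocab = Tokens()
vocab.add(['mov', 'rax', 'CONST', 'mov'])      # a repeated name only bumps its count
print(vocab.size())                            # 3 distinct tokens
weights = vocab.weights()                      # count ** 0.75 sampling weights
negatives = vocab.sample(batch_size=4, num=5)  # tensor of shape (4, 5) for negative sampling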

asm2vec/model.py

Lines changed: 22 additions & 13 deletions
@@ -3,41 +3,50 @@
 
 bce, sigmoid, softmax = nn.BCELoss(), nn.Sigmoid(), nn.Softmax(dim=1)
 
+
 class ASM2VEC(nn.Module):
     def __init__(self, vocab_size, function_size, embedding_size):
         super(ASM2VEC, self).__init__()
-        self.embeddings = nn.Embedding(vocab_size, embedding_size, _weight=torch.zeros(vocab_size, embedding_size))
-        self.embeddings_f = nn.Embedding(function_size, 2 * embedding_size, _weight=(torch.rand(function_size, 2 * embedding_size)-0.5)/embedding_size/2)
-        self.embeddings_r = nn.Embedding(vocab_size, 2 * embedding_size, _weight=(torch.rand(vocab_size, 2 * embedding_size)-0.5)/embedding_size/2)
+        self.embeddings = nn.Embedding(vocab_size, embedding_size, _weight=torch.zeros(vocab_size, embedding_size))
+        self.embeddings_f = nn.Embedding(function_size, 2 * embedding_size,
+                                         _weight=(torch.rand(function_size, 2 * embedding_size)-0.5)/embedding_size/2)
+        self.embeddings_r = nn.Embedding(vocab_size, 2 * embedding_size,
+                                         _weight=(torch.rand(vocab_size, 2 * embedding_size)-0.5)/embedding_size/2)
 
     def update(self, function_size_new, vocab_size_new):
         device = self.embeddings.weight.device
-        vocab_size, function_size, embedding_size = self.embeddings.num_embeddings, self.embeddings_f.num_embeddings, self.embeddings.embedding_dim
+        vocab_size, function_size, embedding_size = (self.embeddings.num_embeddings,
+                                                     self.embeddings_f.num_embeddings, self.embeddings.embedding_dim)
         if vocab_size_new != vocab_size:
-            weight = torch.cat([self.embeddings.weight, torch.zeros(vocab_size_new - vocab_size, embedding_size).to(device)])
+            weight = torch.cat([self.embeddings.weight, torch.zeros(vocab_size_new - vocab_size, embedding_size).
+                                to(device)])
             self.embeddings = nn.Embedding(vocab_size_new, embedding_size, _weight=weight)
-            weight_r = torch.cat([self.embeddings_r.weight, ((torch.rand(vocab_size_new - vocab_size, 2 * embedding_size)-0.5)/embedding_size/2).to(device)])
+            weight_r = torch.cat([self.embeddings_r.weight,
+                                  ((torch.rand(vocab_size_new - vocab_size, 2 * embedding_size)-0.5)/embedding_size/2)
+                                  .to(device)])
             self.embeddings_r = nn.Embedding(vocab_size_new, 2 * embedding_size, _weight=weight_r)
-        self.embeddings_f = nn.Embedding(function_size_new, 2 * embedding_size, _weight=((torch.rand(function_size_new, 2 * embedding_size)-0.5)/embedding_size/2).to(device))
+        self.embeddings_f = nn.Embedding(function_size_new, 2 * embedding_size,
+                                         _weight=((torch.rand(function_size_new, 2 * embedding_size)-0.5) /
+                                                  embedding_size/2).to(device))
 
     def v(self, inp):
-        e = self.embeddings(inp[:,1:])
-        v_f = self.embeddings_f(inp[:,0])
-        v_prev = torch.cat([e[:,0], (e[:,1] + e[:,2]) / 2], dim=1)
-        v_next = torch.cat([e[:,3], (e[:,4] + e[:,5]) / 2], dim=1)
+        e = self.embeddings(inp[:, 1:])
+        v_f = self.embeddings_f(inp[:, 0])
+        v_prev = torch.cat([e[:, 0], (e[:, 1] + e[:, 2]) / 2], dim=1)
+        v_next = torch.cat([e[:, 3], (e[:, 4] + e[:, 5]) / 2], dim=1)
         v = ((v_f + v_prev + v_next) / 3).unsqueeze(2)
         return v
 
     def forward(self, inp, pos, neg):
         device, batch_size = inp.device, inp.shape[0]
         v = self.v(inp)
-        # negative sampling loss
         pred = torch.bmm(self.embeddings_r(torch.cat([pos, neg], dim=1)), v).squeeze()
         label = torch.cat([torch.ones(batch_size, 3), torch.zeros(batch_size, neg.shape[1])], dim=1).to(device)
         return bce(sigmoid(pred), label)
 
     def predict(self, inp, pos):
         device, batch_size = inp.device, inp.shape[0]
         v = self.v(inp)
-        probs = torch.bmm(self.embeddings_r(torch.arange(self.embeddings_r.num_embeddings).repeat(batch_size, 1).to(device)), v).squeeze(dim=2)
+        probs = torch.bmm(self.embeddings_r(torch.arange(self.embeddings_r.num_embeddings).repeat(batch_size, 1).
+                                            to(device)), v).squeeze(dim=2)
         return softmax(probs)
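To make the expected tensor shapes concrete, here is a minimal forward-pass sketch consistent with v() and forward() above; the sizes are arbitrary illustration values, and the column interpretation (function index first, then six context-token indices) is inferred from v(), not stated in the commit.

import torch
from asm2vec.model import ASM2VEC

# Arbitrary sizes for illustration.
model = ASM2VEC(vocab_size=100, function_size=10, embedding_size=50)

inp = torch.randint(0, 100, (8, 7))      # col 0: function index, cols 1-6: context token indices
inp[:, 0] = torch.randint(0, 10, (8,))   # keep function indices within function_size
pos = torch.randint(0, 100, (8, 3))      # the 3 target tokens of the current instruction
neg = torch.randint(0, 100, (8, 25))     # negative samples
loss = model(inp, pos, neg)              # BCE over sigmoid scores (negative-sampling loss)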
