Skip to content

Commit 8cbbd4c

Browse files
committed
1.6: Improve heap safety in allocator
1 parent 171c976 commit 8cbbd4c

File tree

1 file changed

+92
-52
lines changed

1 file changed

+92
-52
lines changed

Client/multiplayer_sa/CMultiplayerSA_FixMallocAlign.cpp

Lines changed: 92 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ namespace memory
2222
constexpr std::uint32_t NULL_PAGE_BOUNDARY = 0x10000;
2323
constexpr std::uint32_t MAX_ADDRESS_SPACE = 0xFFFFFFFF;
2424
constexpr std::uint32_t POINTER_SIZE = 4;
25+
constexpr std::uint32_t POINTER_METADATA_OVERHEAD = POINTER_SIZE * 2;
26+
constexpr std::uint32_t METADATA_MAGIC = 0x4D544100; // 'MTA\0'
27+
constexpr std::uint32_t METADATA_MAGIC_MASK = 0xFFFFFFFE;
28+
constexpr std::uint32_t METADATA_FLAG_VIRTUALALLOC = 0x1;
2529

2630
constexpr bool is_valid_alignment(std::size_t alignment) noexcept
2731
{
@@ -30,7 +34,7 @@ namespace memory
3034

3135
void* SafeMallocAlignVirtual(std::size_t size, std::size_t alignment) noexcept;
3236

33-
// Aligned malloc - stores pointer at result-4
37+
// Aligned malloc - stores pointer at result-4 and metadata at result-8
3438
void* SafeMallocAlign(std::size_t size, std::size_t alignment) noexcept
3539
{
3640
// Check alignment
@@ -58,18 +62,19 @@ namespace memory
5862
const std::uint32_t align_u32 = static_cast<std::uint32_t>(alignment);
5963

6064
// Prevent intermediate overflow
61-
if (size_u32 > UINT32_MAX - align_u32)
65+
if (align_u32 > UINT32_MAX - POINTER_METADATA_OVERHEAD)
6266
{
6367
errno = ENOMEM;
6468
return nullptr;
6569
}
66-
// Now safe to add size_u32 + align_u32
67-
if (size_u32 + align_u32 > UINT32_MAX - POINTER_SIZE)
70+
const std::uint32_t alignment_overhead = align_u32 + POINTER_METADATA_OVERHEAD;
71+
72+
if (size_u32 > UINT32_MAX - alignment_overhead)
6873
{
6974
errno = ENOMEM;
7075
return nullptr;
7176
}
72-
const std::uint32_t total_size = size_u32 + align_u32 + POINTER_SIZE;
77+
const std::uint32_t total_size = size_u32 + alignment_overhead;
7378

7479
void* raw_memory = malloc(total_size);
7580
if (!raw_memory)
@@ -80,16 +85,16 @@ namespace memory
8085

8186
const std::uint32_t raw_addr = reinterpret_cast<std::uint32_t>(raw_memory);
8287

83-
if (raw_addr > MAX_ADDRESS_SPACE - POINTER_SIZE - align_u32 + 1)
88+
if (raw_addr > MAX_ADDRESS_SPACE - POINTER_METADATA_OVERHEAD - align_u32 + 1)
8489
{
8590
free(raw_memory);
8691
errno = ENOMEM;
8792
return nullptr;
8893
}
8994

90-
const std::uint32_t aligned_addr = (raw_addr + POINTER_SIZE + align_u32 - 1) & ~(align_u32 - 1);
95+
const std::uint32_t aligned_addr = (raw_addr + POINTER_METADATA_OVERHEAD + align_u32 - 1) & ~(align_u32 - 1);
9196

92-
if (aligned_addr < raw_addr + POINTER_SIZE || aligned_addr + size_u32 > raw_addr + total_size || aligned_addr + size_u32 < aligned_addr)
97+
if (aligned_addr < raw_addr + POINTER_METADATA_OVERHEAD || aligned_addr + size_u32 > raw_addr + total_size || aligned_addr + size_u32 < aligned_addr)
9398
{
9499
free(raw_memory);
95100
errno = EINVAL;
@@ -100,15 +105,20 @@ namespace memory
100105

101106
// Validate store location
102107
void** store_location = reinterpret_cast<void**>(aligned_addr - POINTER_SIZE);
103-
if (reinterpret_cast<std::uint32_t>(store_location) < raw_addr ||
104-
reinterpret_cast<std::uint32_t>(store_location) > raw_addr + total_size - POINTER_SIZE)
108+
std::uint32_t* metadata_location = reinterpret_cast<std::uint32_t*>(aligned_addr - POINTER_METADATA_OVERHEAD);
109+
const std::uint32_t store_addr = reinterpret_cast<std::uint32_t>(store_location);
110+
const std::uint32_t metadata_addr = reinterpret_cast<std::uint32_t>(metadata_location);
111+
112+
if (store_addr < raw_addr || store_addr > raw_addr + total_size - POINTER_SIZE ||
113+
metadata_addr < raw_addr || metadata_addr > raw_addr + total_size - POINTER_SIZE)
105114
{
106115
free(raw_memory);
107116
errno = EFAULT;
108117
return nullptr;
109118
}
110119

111120
*store_location = raw_memory;
121+
*metadata_location = METADATA_MAGIC;
112122

113123
return result;
114124
}
@@ -129,18 +139,19 @@ namespace memory
129139
const std::uint32_t align_u32 = static_cast<std::uint32_t>(alignment);
130140

131141
// Prevent intermediate overflow
132-
if (size_u32 > UINT32_MAX - align_u32)
142+
if (align_u32 > UINT32_MAX - POINTER_METADATA_OVERHEAD)
133143
{
134144
errno = ENOMEM;
135145
return nullptr;
136146
}
137-
// Now safe to add size_u32 + align_u32
138-
if (size_u32 + align_u32 > UINT32_MAX - POINTER_SIZE)
147+
const std::uint32_t alignment_overhead = align_u32 + POINTER_METADATA_OVERHEAD;
148+
149+
if (size_u32 > UINT32_MAX - alignment_overhead)
139150
{
140151
errno = ENOMEM;
141152
return nullptr;
142153
}
143-
const std::uint32_t total_size = size_u32 + align_u32 + POINTER_SIZE;
154+
const std::uint32_t total_size = size_u32 + alignment_overhead;
144155

145156
void* raw_memory = malloc(total_size);
146157
if (!raw_memory)
@@ -151,16 +162,16 @@ namespace memory
151162

152163
const std::uint32_t raw_addr = reinterpret_cast<std::uint32_t>(raw_memory);
153164

154-
if (raw_addr > MAX_ADDRESS_SPACE - POINTER_SIZE - align_u32 + 1)
165+
if (raw_addr > MAX_ADDRESS_SPACE - POINTER_METADATA_OVERHEAD - align_u32 + 1)
155166
{
156167
free(raw_memory);
157168
errno = ENOMEM;
158169
return nullptr;
159170
}
160171

161-
const std::uint32_t aligned_addr = (raw_addr + POINTER_SIZE + align_u32 - 1) & ~(align_u32 - 1);
172+
const std::uint32_t aligned_addr = (raw_addr + POINTER_METADATA_OVERHEAD + align_u32 - 1) & ~(align_u32 - 1);
162173

163-
if (aligned_addr < raw_addr + POINTER_SIZE || aligned_addr + size_u32 > raw_addr + total_size || aligned_addr + size_u32 < aligned_addr)
174+
if (aligned_addr < raw_addr + POINTER_METADATA_OVERHEAD || aligned_addr + size_u32 > raw_addr + total_size || aligned_addr + size_u32 < aligned_addr)
164175
{
165176
free(raw_memory);
166177
errno = EINVAL;
@@ -170,15 +181,20 @@ namespace memory
170181
void* result = reinterpret_cast<void*>(aligned_addr);
171182

172183
void** store_location = reinterpret_cast<void**>(aligned_addr - POINTER_SIZE);
173-
if (reinterpret_cast<std::uint32_t>(store_location) < raw_addr ||
174-
reinterpret_cast<std::uint32_t>(store_location) > raw_addr + total_size - POINTER_SIZE)
184+
std::uint32_t* metadata_location = reinterpret_cast<std::uint32_t*>(aligned_addr - POINTER_METADATA_OVERHEAD);
185+
const std::uint32_t store_addr = reinterpret_cast<std::uint32_t>(store_location);
186+
const std::uint32_t metadata_addr = reinterpret_cast<std::uint32_t>(metadata_location);
187+
188+
if (store_addr < raw_addr || store_addr > raw_addr + total_size - POINTER_SIZE ||
189+
metadata_addr < raw_addr || metadata_addr > raw_addr + total_size - POINTER_SIZE)
175190
{
176191
free(raw_memory);
177192
errno = EFAULT;
178193
return nullptr;
179194
}
180195

181196
*store_location = raw_memory;
197+
*metadata_location = METADATA_MAGIC;
182198

183199
return result;
184200
}
@@ -208,23 +224,25 @@ namespace memory
208224
const std::uint32_t align_u32 = static_cast<std::uint32_t>(alignment);
209225
const std::uint32_t padding = (align_u32 <= 64) ? 32 : VIRTUALALLOC_PADDING;
210226

211-
if (align_u32 > UINT32_MAX - POINTER_SIZE)
227+
if (align_u32 > UINT32_MAX - POINTER_METADATA_OVERHEAD)
212228
{
213229
errno = ENOMEM;
214230
return nullptr;
215231
}
216-
if (align_u32 + POINTER_SIZE > UINT32_MAX - padding)
232+
const std::uint32_t alignment_overhead = align_u32 + POINTER_METADATA_OVERHEAD;
233+
234+
if (alignment_overhead > UINT32_MAX - padding)
217235
{
218236
errno = ENOMEM;
219237
return nullptr;
220238
}
221-
if (size_u32 > UINT32_MAX - align_u32 - POINTER_SIZE - padding)
239+
if (size_u32 > UINT32_MAX - alignment_overhead - padding)
222240
{
223241
errno = ENOMEM;
224242
return nullptr;
225243
}
226244

227-
const DWORD total_size = size_u32 + align_u32 + POINTER_SIZE + padding;
245+
const DWORD total_size = size_u32 + alignment_overhead + padding;
228246

229247
void* raw_ptr = VirtualAlloc(nullptr, static_cast<SIZE_T>(total_size), MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE);
230248
if (!raw_ptr)
@@ -235,17 +253,17 @@ namespace memory
235253

236254
const std::uint32_t raw_addr = reinterpret_cast<std::uint32_t>(raw_ptr);
237255

238-
if (raw_addr > MAX_ADDRESS_SPACE - POINTER_SIZE - align_u32 + 1)
256+
if (raw_addr > MAX_ADDRESS_SPACE - POINTER_METADATA_OVERHEAD - align_u32 + 1)
239257
{
240258
BOOL vfree_result = VirtualFree(raw_ptr, 0, MEM_RELEASE);
241259
(void)vfree_result;
242260
errno = ENOMEM;
243261
return nullptr;
244262
}
245263

246-
const std::uint32_t aligned_addr = (raw_addr + POINTER_SIZE + align_u32 - 1) & ~(align_u32 - 1);
264+
const std::uint32_t aligned_addr = (raw_addr + POINTER_METADATA_OVERHEAD + align_u32 - 1) & ~(align_u32 - 1);
247265

248-
if (aligned_addr < raw_addr + POINTER_SIZE || aligned_addr + size_u32 > raw_addr + total_size || aligned_addr + size_u32 < aligned_addr)
266+
if (aligned_addr < raw_addr + POINTER_METADATA_OVERHEAD || aligned_addr + size_u32 > raw_addr + total_size || aligned_addr + size_u32 < aligned_addr)
249267
{
250268
BOOL vfree_result = VirtualFree(raw_ptr, 0, MEM_RELEASE);
251269
(void)vfree_result;
@@ -257,8 +275,12 @@ namespace memory
257275

258276
// Validate store location
259277
void** store_location = reinterpret_cast<void**>(aligned_addr - POINTER_SIZE);
260-
if (reinterpret_cast<std::uint32_t>(store_location) < raw_addr ||
261-
reinterpret_cast<std::uint32_t>(store_location) > raw_addr + total_size - POINTER_SIZE)
278+
std::uint32_t* metadata_location = reinterpret_cast<std::uint32_t*>(aligned_addr - POINTER_METADATA_OVERHEAD);
279+
const std::uint32_t store_addr = reinterpret_cast<std::uint32_t>(store_location);
280+
const std::uint32_t metadata_addr = reinterpret_cast<std::uint32_t>(metadata_location);
281+
282+
if (store_addr < raw_addr || store_addr > raw_addr + total_size - POINTER_SIZE ||
283+
metadata_addr < raw_addr || metadata_addr > raw_addr + total_size - POINTER_SIZE)
262284
{
263285
BOOL vfree_result = VirtualFree(raw_ptr, 0, MEM_RELEASE);
264286
(void)vfree_result;
@@ -267,6 +289,7 @@ namespace memory
267289
}
268290

269291
*store_location = raw_ptr;
292+
*metadata_location = METADATA_MAGIC | METADATA_FLAG_VIRTUALALLOC;
270293

271294
return result;
272295
}
@@ -284,7 +307,13 @@ namespace memory
284307
return;
285308
}
286309

310+
if (ptr_addr < POINTER_METADATA_OVERHEAD)
311+
{
312+
return;
313+
}
314+
287315
void** read_location = reinterpret_cast<void**>(ptr_addr - POINTER_SIZE);
316+
std::uint32_t* metadata_location = reinterpret_cast<std::uint32_t*>(ptr_addr - POINTER_METADATA_OVERHEAD);
288317

289318
// Validate memory readable
290319
MEMORY_BASIC_INFORMATION mbi_read;
@@ -295,7 +324,35 @@ namespace memory
295324
return;
296325
}
297326

327+
const std::uint32_t metadata_addr = reinterpret_cast<std::uint32_t>(metadata_location);
328+
const std::uint32_t base_addr = reinterpret_cast<std::uint32_t>(mbi_read.BaseAddress);
329+
330+
if (mbi_read.RegionSize == 0 || mbi_read.RegionSize > static_cast<SIZE_T>(MAX_ADDRESS_SPACE))
331+
{
332+
return;
333+
}
334+
335+
const std::uint32_t region_size_u32 = static_cast<std::uint32_t>(mbi_read.RegionSize);
336+
337+
if (base_addr > MAX_ADDRESS_SPACE - region_size_u32)
338+
{
339+
return;
340+
}
341+
342+
const std::uint32_t region_end = base_addr + region_size_u32;
343+
344+
if (region_size_u32 < POINTER_SIZE || metadata_addr < base_addr || metadata_addr > region_end - POINTER_SIZE)
345+
{
346+
return;
347+
}
348+
298349
void* original_ptr = *read_location;
350+
const std::uint32_t metadata = *metadata_location;
351+
352+
if ((metadata & METADATA_MAGIC_MASK) != METADATA_MAGIC)
353+
{
354+
return;
355+
}
299356

300357
if (!original_ptr)
301358
{
@@ -310,12 +367,12 @@ namespace memory
310367
}
311368

312369
const std::uint32_t distance = ptr_addr - original_addr;
313-
if (distance > MAX_ALIGNMENT + POINTER_SIZE)
370+
if (distance > MAX_ALIGNMENT + POINTER_METADATA_OVERHEAD)
314371
{
315372
return; // Beyond maximum possible alignment
316373
}
317374

318-
if (ptr_addr < POINTER_SIZE || original_addr > ptr_addr - POINTER_SIZE)
375+
if (original_addr > ptr_addr - POINTER_SIZE)
319376
{
320377
return; // Violates our storage pattern
321378
}
@@ -325,31 +382,13 @@ namespace memory
325382
return;
326383
}
327384

328-
MEMORY_BASIC_INFORMATION mbi;
329-
SIZE_T mbi_result = VirtualQuery(original_ptr, &mbi, sizeof(mbi));
330-
331-
if (mbi_result == sizeof(mbi))
385+
if ((metadata & METADATA_FLAG_VIRTUALALLOC) != 0)
332386
{
333-
const std::uint32_t base_addr = reinterpret_cast<std::uint32_t>(mbi.AllocationBase);
334-
335-
// Validate region size
336-
if (mbi.RegionSize > 0 && mbi.RegionSize <= static_cast<SIZE_T>(MAX_ADDRESS_SPACE) &&
337-
base_addr <= MAX_ADDRESS_SPACE - static_cast<std::uint32_t>(mbi.RegionSize))
338-
{
339-
const std::uint32_t region_size_u32 = static_cast<std::uint32_t>(mbi.RegionSize);
340-
const std::uint32_t region_end = base_addr + region_size_u32;
341-
342-
// Use VirtualFree if matches
343-
if (mbi.Type == MEM_PRIVATE && mbi.State == MEM_COMMIT && original_addr >= base_addr && original_addr < region_end)
344-
{
345-
BOOL vfree_result = VirtualFree(mbi.AllocationBase, 0, MEM_RELEASE);
346-
(void)vfree_result;
347-
return;
348-
}
349-
}
387+
BOOL vfree_result = VirtualFree(original_ptr, 0, MEM_RELEASE);
388+
(void)vfree_result;
389+
return;
350390
}
351391

352-
// Use free for malloc
353392
free(original_ptr);
354393
}
355394
} // namespace memory
@@ -400,3 +439,4 @@ void CMultiplayerSA::InitHooks_FixMallocAlign()
400439
HookInstall(HOOKPOS_CMemoryMgr_MallocAlign, reinterpret_cast<DWORD>(HOOK_CMemoryMgr_MallocAlign), HOOKSIZE_CMemoryMgr_MallocAlign);
401440
HookInstall(HOOKPOS_CMemoryMgr_FreeAlign, reinterpret_cast<DWORD>(HOOK_CMemoryMgr_FreeAlign), HOOKSIZE_CMemoryMgr_FreeAlign);
402441
}
442+

0 commit comments

Comments
 (0)