Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions mmdb_writer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# coding: utf-8
__version__ = '0.1.0'
__version__ = '0.1.1'

import logging
import math
import struct
import time
from typing import Union

from netaddr import IPSet
from netaddr import IPSet, IPNetwork

MMDBType = Union[dict, list, str, bytes, int, bool]

Expand Down Expand Up @@ -426,18 +426,24 @@ def insert_network(self, network: IPSet, content: MMDBType):
cidr = cidr.ipv6(True)
node = self.tree
bits = list(bits_rstrip(cidr.value, self._bit_length, cidr.prefixlen))
try:
for i in bits[:-1]:
node = node.get_or_create(i)
if node[bits[-1]] is not None:
logger.warning("address %s info is not empty: %s, will override with %s",
cidr, node[bits[-1]], leaf)
except (AttributeError, TypeError) as e:
bits_str = ''.join(map(str, bits))
logger.warning("{cidr}({bits_str})[{content}] is subnet of {node}, pass!"
.format(cidr=cidr, bits_str=bits_str, content=content, node=node))
continue
node[bits[-1]] = leaf
current_node = node
supernet_leaf = None # Tracks whether we are inserting into a subnet
for (index, ip_bit) in enumerate(bits[:-1]):
previous_node = current_node
current_node = previous_node.get_or_create(ip_bit)

if isinstance(current_node, SearchTreeLeaf):
current_cidr = IPNetwork((int(''.join(map(str, bits[:index + 1])).ljust(self._bit_length, '0'), 2), index + 1))
logger.info(f"Inserting {cidr} ({content}) into subnet of {current_cidr} ({current_node.value})")
supernet_leaf = current_node
current_node = SearchTreeNode()
previous_node[ip_bit] = current_node

if supernet_leaf:
next_bit = bits[index + 1]
# Insert supernet information on each inverse bit of the current subnet
current_node[1 - next_bit] = supernet_leaf
current_node[bits[-1]] = leaf

def to_db_file(self, filename: str):
return TreeWriter(self.tree, self._build_meta()).write(filename)
Expand Down
3 changes: 0 additions & 3 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
# coding: utf-8

# TODO: add tests
80 changes: 80 additions & 0 deletions tests/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# coding: utf-8
import logging
import os.path
import unittest

import maxminddb
from netaddr import IPSet

from mmdb_writer import MMDBWriter

logging.basicConfig(format="[%(asctime)s: %(levelname)s] %(message)s", level=logging.INFO)
info1 = {'country': 'c1', 'isp': 'ISP1'}
info2 = {'country': 'c2', 'isp': 'ISP2'}


class TestBuild(unittest.TestCase):
def setUp(self) -> None:
self.filename = '_test.mmdb'

def tearDown(self) -> None:
if os.path.exists(self.filename):
os.remove(self.filename)

def test_metadata(self):
ip_version = 6
database_type = 'test_database_type'
languages = ['en', 'ch']
description = {'en': 'en test', 'ch': 'ch test'}
writer = MMDBWriter(ip_version=ip_version, database_type=database_type,
languages=languages, description=description,
ipv4_compatible=False)
writer.to_db_file(self.filename)
for mode in (maxminddb.MODE_MMAP_EXT, maxminddb.MODE_MMAP, maxminddb.MODE_FILE):
m = maxminddb.open_database(self.filename, mode=mode)
self.assertEqual(ip_version, m.metadata().ip_version, mode)
self.assertEqual(database_type, m.metadata().database_type, mode)
self.assertEqual(languages, m.metadata().languages, mode)
self.assertEqual(description, m.metadata().description, mode)
m.close()

def test_encode_type(self):
writer = MMDBWriter()
info = {'int': 1, 'float': 1.0 / 3, 'list': ['a', 'b', 'c'], 'dict': {'k': 'v'}, 'bytes': b'bytes', 'str': 'str'}
writer.insert_network(IPSet(['1.1.0.0/24']), info)
writer.to_db_file(self.filename)
for mode in (maxminddb.MODE_MMAP_EXT, maxminddb.MODE_MMAP, maxminddb.MODE_FILE):
m = maxminddb.open_database(self.filename, mode=mode)
get = m.get('1.1.0.255')
self.assertEqual(len(info), len(get), mode)
self.assertEqual(info['int'], get['int'], mode)
self.assertTrue(abs(info['float'] - get['float']) < 1e-5, mode)
self.assertEqual(info['list'], get['list'], mode)
self.assertEqual(info['dict'], get['dict'], mode)
self.assertEqual(info['bytes'], get['bytes'], mode)
self.assertEqual(info['str'], get['str'], mode)
m.close()

def test_4in6(self):
writer = MMDBWriter(ip_version=6, ipv4_compatible=True)
writer.insert_network(IPSet(['1.1.0.0/24']), info1)
writer.insert_network(IPSet(['fe80::/16']), info2)
writer.to_db_file(self.filename)
for mode in (maxminddb.MODE_MMAP_EXT, maxminddb.MODE_MMAP, maxminddb.MODE_FILE):
m = maxminddb.open_database(self.filename, mode=mode)
self.assertEqual(info1, m.get('1.1.0.1'), mode)
self.assertEqual(info2, m.get('fe80::1'), mode)
m.close()

def test_insert_subnet(self):
writer = MMDBWriter()
writer.insert_network(IPSet(['1.0.0.0/8']), info1)
writer.insert_network(IPSet(['1.10.10.0/24']), info2)
writer.to_db_file(self.filename)
for mode in (maxminddb.MODE_MMAP_EXT, maxminddb.MODE_MMAP, maxminddb.MODE_FILE):
m = maxminddb.open_database(self.filename, mode=mode)
self.assertEqual(info1, m.get('1.1.0.1'), mode)
self.assertEqual(info1, m.get('1.10.0.1'), mode)
self.assertEqual(info2, m.get('1.10.10.1'), mode)
m.close()