diff --git a/internal/decoder/data_decoder.go b/internal/decoder/data_decoder.go index fb65134..06958e7 100644 --- a/internal/decoder/data_decoder.go +++ b/internal/decoder/data_decoder.go @@ -424,28 +424,36 @@ func (d *DataDecoder) decodeKey(offset uint) ([]byte, uint, error) { // the one at the offset passed in. The size bits have different meanings for // different data types. func (d *DataDecoder) nextValueOffset(offset, numberToSkip uint) (uint, error) { - if numberToSkip == 0 { - return offset, nil - } - kindNum, size, offset, err := d.decodeCtrlData(offset) - if err != nil { - return 0, err - } - switch kindNum { - case KindPointer: - _, offset, err = d.decodePointer(size, offset) + for numberToSkip > 0 { + kindNum, size, newOffset, err := d.decodeCtrlData(offset) if err != nil { return 0, err } - case KindMap: - numberToSkip += 2 * size - case KindSlice: - numberToSkip += size - case KindBool: - default: - offset += size + + switch kindNum { + case KindPointer: + // A pointer value is represented by its pointer token only. + // To skip it, just move past the pointer bytes; do NOT follow + // the pointer target here. + _, ptrEndOffset, err2 := d.decodePointer(size, newOffset) + if err2 != nil { + return 0, err2 + } + newOffset = ptrEndOffset + case KindMap: + numberToSkip += 2 * size + case KindSlice: + numberToSkip += size + case KindBool: + // size encodes the boolean; nothing else to skip + default: + newOffset += size + } + + offset = newOffset + numberToSkip-- } - return d.nextValueOffset(offset, numberToSkip-1) + return offset, nil } func (d *DataDecoder) sizeFromCtrlByte( diff --git a/reader.go b/reader.go index 0c517da..7429532 100644 --- a/reader.go +++ b/reader.go @@ -515,6 +515,67 @@ func readNodeBySize(buffer []byte, offset, bit, recordSize uint) (uint, error) { } } +// readNodePairBySize reads both left (bit=0) and right (bit=1) child pointers +// for a node at the given base offset according to the record size. This reduces +// duplicate bound checks and byte fetches when both children are needed. +func readNodePairBySize(buffer []byte, baseOffset, recordSize uint) (left, right uint, err error) { + bufferLen := uint(len(buffer)) + switch recordSize { + case 24: + // Each child is 3 bytes; total 6 bytes starting at baseOffset + if baseOffset > bufferLen-6 { + return 0, 0, mmdberrors.NewInvalidDatabaseError( + "bounds check failed: insufficient buffer for 24-bit node pair read", + ) + } + o := baseOffset + left = (uint(buffer[o]) << 16) | (uint(buffer[o+1]) << 8) | uint(buffer[o+2]) + o += 3 + right = (uint(buffer[o]) << 16) | (uint(buffer[o+1]) << 8) | uint(buffer[o+2]) + return left, right, nil + case 28: + // Left uses high nibble of shared byte, right uses low nibble. + // Layout: [A B C S][D E F] where S provides 4 shared bits for each child + if baseOffset > bufferLen-7 { + return 0, 0, mmdberrors.NewInvalidDatabaseError( + "bounds check failed: insufficient buffer for 28-bit node pair read", + ) + } + // Left child (bit=0): uses high nibble of shared byte + shared := uint(buffer[baseOffset+3]) + left = ((shared & 0xF0) << 20) | + (uint(buffer[baseOffset]) << 16) | + (uint(buffer[baseOffset+1]) << 8) | + uint(buffer[baseOffset+2]) + // Right child (bit=1): uses low nibble of shared byte, next 3 bytes + right = ((shared & 0x0F) << 24) | + (uint(buffer[baseOffset+4]) << 16) | + (uint(buffer[baseOffset+5]) << 8) | + uint(buffer[baseOffset+6]) + return left, right, nil + case 32: + // Each child is 4 bytes; total 8 bytes + if baseOffset > bufferLen-8 { + return 0, 0, mmdberrors.NewInvalidDatabaseError( + "bounds check failed: insufficient buffer for 32-bit node pair read", + ) + } + o := baseOffset + left = (uint(buffer[o]) << 24) | + (uint(buffer[o+1]) << 16) | + (uint(buffer[o+2]) << 8) | + uint(buffer[o+3]) + o += 4 + right = (uint(buffer[o]) << 24) | + (uint(buffer[o+1]) << 16) | + (uint(buffer[o+2]) << 8) | + uint(buffer[o+3]) + return left, right, nil + default: + return 0, 0, mmdberrors.NewInvalidDatabaseError("unsupported record size") + } +} + func (r *Reader) traverseTree(ip netip.Addr, node uint, stopBit int) (uint, int, error) { switch r.Metadata.RecordSize { case 24: diff --git a/traverse.go b/traverse.go index 5559818..1a49eb3 100644 --- a/traverse.go +++ b/traverse.go @@ -223,8 +223,12 @@ func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) } ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8)) - offset := node.pointer * r.nodeOffsetMult - rightPointer, err := readNodeBySize(r.buffer, offset, 1, r.Metadata.RecordSize) + baseOffset := node.pointer * r.nodeOffsetMult + leftPointer, rightPointer, err := readNodePairBySize( + r.buffer, + baseOffset, + r.Metadata.RecordSize, + ) if err != nil { yield(Result{ ip: mappedIP(node.ip), @@ -241,15 +245,6 @@ func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) bit: node.bit, }) - leftPointer, err := readNodeBySize(r.buffer, offset, 0, r.Metadata.RecordSize) - if err != nil { - yield(Result{ - ip: mappedIP(node.ip), - prefixLen: uint8(node.bit), - err: err, - }) - return - } node.pointer = leftPointer } }