--- IPv4 address and CIDR range helpers built on LuaJIT's `bit` library.
-- Addresses, network addresses and broadcast addresses are all handled as
-- signed 32-bit integers, as returned by `bit.tobit`.
local ipairs, tonumber, tostring, type = ipairs, tonumber, tostring, type

local bit = require("bit")
local tobit = bit.tobit
local lshift = bit.lshift
local band = bit.band
local bor = bit.bor
local xor = bit.bxor

local byte = string.byte
local str_find = string.find
local str_sub = string.sub

-- resty.lrucache instance used by ip2bin(); stays nil until
-- enable_lrucache() succeeds.
local lrucache = nil

local _M = {
    _VERSION = '0.02',
}

-- Precompute the network mask for every prefix length 1..32 (the top
-- `prefix` bits set) and its inverse (the host bits). Both tables are
-- indexed by the numeric prefix length.
local bin_masks = {}
local bin_inverted_masks = {}
do
    local all_ones = tobit(0xFFFFFFFF)
    for prefix = 1, 32 do
        local mask = lshift(tobit((2 ^ prefix) - 1), 32 - prefix)
        bin_masks[prefix] = mask
        bin_inverted_masks[prefix] = xor(mask, all_ones)
    end
end

-- Error logger: the nginx error log when the `ngx` global exists,
-- print() otherwise.
local log_err
if ngx then
    log_err = function(...)
        ngx.log(ngx.ERR, ...)
    end
else
    log_err = function(...)
        print(...)
    end
end

--- Enable caching of ip2bin() results in a resty.lrucache instance.
-- @param size optional maximum number of cached IPs (default 4000,
--   roughly 1MB of memory according to the original note).
-- @return true on success, or nil and an error message.
local function enable_lrucache(size)
    size = size or 4000
    -- Honour the caller's size (it used to be ignored in favour of 4000).
    local lrucache_obj, err = require("resty.lrucache").new(size)
    if not lrucache_obj then
        return nil, "failed to create the cache: " .. (err or "unknown")
    end
    lrucache = lrucache_obj
    return true
end
_M.enable_lrucache = enable_lrucache

-- Split a dotted-quad string into exactly four (not yet validated) octet
-- strings. Returns the list, or nil and "Invalid IP" when the input does
-- not have exactly four dot-separated parts.
local function split_octets(input)
    local octs = {}
    local prev = 1
    for i = 1, 4 do
        local pos = str_find(input, ".", prev, true)
        if pos then
            if i == 4 then
                -- Another dot after the fourth octet: too many parts
                return nil, "Invalid IP"
            end
            octs[i] = str_sub(input, prev, pos - 1)
            prev = pos + 1
        elseif i == 4 then
            -- Last octet: everything up to the end of the string
            octs[i] = str_sub(input, prev, -1)
        else
            return nil, "Invalid IP"
        end
    end
    return octs
end

--- Convert a dotted-quad IPv4 string to a 32-bit integer.
-- When the cache is enabled, results are memoised (the octet table is
-- shared between callers in that case).
-- @param ip the address string, e.g. "10.0.0.1".
-- @return the address as a signed 32-bit integer plus a table of the four
--   octet values, or nil and an error message.
local function ip2bin(ip)
    if type(ip) ~= "string" then
        return nil, "IP must be a string"
    end

    if lrucache then
        local cached = lrucache:get(ip)
        if cached then
            return cached[1], cached[2]
        end
    end

    local octets = split_octets(ip)
    if not octets or #octets ~= 4 then
        return nil, "Invalid IP"
    end

    local bin_octets = {}
    local bin_ip = 0
    for i, octet in ipairs(octets) do
        local bin_octet = tonumber(octet)
        -- tonumber() also accepts "-1", "1e-1", "nan", ...; an octet must be
        -- a whole number in 0..255. `% 1 ~= 0` also rejects NaN and inf.
        if not bin_octet or bin_octet % 1 ~= 0
            or bin_octet < 0 or bin_octet > 255 then
            return nil, "Invalid octet: " .. tostring(octet)
        end
        bin_octet = tobit(bin_octet)
        bin_octets[i] = bin_octet
        bin_ip = bor(lshift(bin_octet, 8 * (4 - i)), bin_ip)
    end

    if lrucache then
        lrucache:set(ip, { bin_ip, bin_octets })
    end
    return bin_ip, bin_octets
end
_M.ip2bin = ip2bin

-- Split "net/prefix" on the first "/" into { net, prefix }; returns
-- { input } when there is no "/".
local function split_cidr(input)
    local pos = str_find(input, "/", 1, true)
    if not pos then
        return { input }
    end
    return { str_sub(input, 1, pos - 1), str_sub(input, pos + 1, -1) }
end

--- Parse a CIDR string (e.g. "10.0.0.0/8") into its address range.
-- A missing prefix means /32. The prefix must be an integer in 1..32.
-- @return lower (network) and upper (broadcast) addresses as signed 32-bit
--   integers, or nil and an error message.
local function parse_cidr(cidr)
    local mask_split = split_cidr(cidr)
    local net = mask_split[1]
    local mask = mask_split[2] or "32"

    local mask_num = tonumber(mask)
    -- Require an integral prefix (rejects "8.5" and NaN). The masks are
    -- looked up by number, so "08" or " 8" no longer index a missing key.
    if not mask_num or mask_num % 1 ~= 0
        or mask_num < 1 or mask_num > 32 then
        return nil, "Invalid prefix: /" .. tostring(mask)
    end

    local bin_net, err = ip2bin(net)
    if not bin_net then
        return nil, err
    end

    local lower = band(bin_net, bin_masks[mask_num])           -- network address
    local upper = bor(lower, bin_inverted_masks[mask_num])     -- broadcast address
    return lower, upper
end
_M.parse_cidr = parse_cidr

--- Parse a list of CIDR strings into a list of { lower, upper } ranges.
-- Invalid entries are logged with log_err and skipped.
local function parse_cidrs(cidrs)
    local out = {}
    for _, cidr in ipairs(cidrs) do
        local lower, upper = parse_cidr(cidr)
        if lower then
            out[#out + 1] = { lower, upper }
        else
            log_err("Error parsing '", cidr, "': ", upper)
        end
    end
    return out
end
_M.parse_cidrs = parse_cidrs

-- Return true if the 32-bit integer `bin_ip` falls inside any
-- { lower, upper } pair in `cidrs`, false otherwise. Signed comparison is
-- sufficient for ranges from parse_cidr: the prefix is always >= 1, so a
-- range never spans the 0x7FFFFFFF/0x80000000 boundary, where signed and
-- unsigned ordering differ.
local function bin_in_cidrs(bin_ip, cidrs)
    for _, cidr in ipairs(cidrs) do
        if bin_ip >= cidr[1] and bin_ip <= cidr[2] then
            return true
        end
    end
    return false
end

--- Check whether a dotted-quad IP string is inside any parsed range.
-- @param ip address string.
-- @param cidrs list returned by parse_cidrs.
-- @return true/false, or nil and an error message if `ip` is invalid.
local function ip_in_cidrs(ip, cidrs)
    local bin_ip, err = ip2bin(ip)
    if not bin_ip then
        return nil, err
    end
    return bin_in_cidrs(bin_ip, cidrs)
end
_M.ip_in_cidrs = ip_in_cidrs

--- Check whether a 4-byte binary IPv4 address is inside any parsed range.
-- @param bin_ip_ngx 4-byte string in network byte order (for example
--   ngx.var.binary_remote_addr for IPv4 clients).
-- @param cidrs list returned by parse_cidrs.
-- @return true/false, or false and "invalid IP address" if the input is not
--   a 4-byte string.
local function binip_in_cidrs(bin_ip_ngx, cidrs)
    if type(bin_ip_ngx) ~= "string" or #bin_ip_ngx ~= 4 then
        return false, "invalid IP address"
    end
    local b1, b2, b3, b4 = byte(bin_ip_ngx, 1, 4)
    local bin_ip = bor(lshift(b1, 24), lshift(b2, 16), lshift(b3, 8), b4)
    return bin_in_cidrs(bin_ip, cidrs)
end
_M.binip_in_cidrs = binip_in_cidrs

return _M