policy.lua 8.56 KB
Newer Older
1
local kres = require('kres')
2

3 4 5 6 7 8 9 10
-- Rule identifiers: a module-wide counter, handed out sequentially from 0
local nextid = 0
local function getruleid()
	local assigned = nextid
	nextid = assigned + 1
	return assigned
end

11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
-- Client sockets usable from inside policy actions.
-- socket_client(host, port) returns a UDP socket connected to host:port, or
-- nil plus an error message. When luasocket (or, below, the ffi library) is
-- unavailable, calling it raises an error instead.
local socket_client = function () return error("missing luasocket, can't create socket client") end
local has_socket, socket = pcall(require, 'socket')
if has_socket then
	socket_client = function (host, port)
		local sock, err
		-- A ':' in the host selects an IPv6 UDP socket
		if host:find(':') then
			sock, err = socket.udp6()
		else
			sock, err = socket.udp()
		end
		if sock then
			local ok
			ok, err = sock:setpeername(host, port)
			if ok then
				return sock
			end
		end
		return nil, err
	end
end
-- Actions that use sockets also need ffi to read packet wire data
local has_ffi, ffi = pcall(require, 'ffi')
if not has_ffi then
	socket_client = function () return error("missing ffi library, required for this policy") end
end

-- Mirror request elsewhere, and continue solving.
-- @param target  'address' or 'address@port'; port defaults to 53
-- @return policy action that sends a copy of the client's query wire data to
--         target over UDP and returns nothing, so rule evaluation continues
--         with the next rule (chain action)
-- NOTE(review): `panic` is not defined in this module; it is assumed to be a
-- global provided by the daemon sandbox - confirm.
local function mirror(target)
	local addr, port = target:match '([^@]*)@?(.*)'
	if not port or #port == 0 then port = 53 end
	local sink, err = socket_client(addr, port)
	if not sink then panic('MIRROR target %s is not valid: %s', target, err) end
	return function(state, req)
		-- Failed requests are not mirrored
		if state == kres.FAIL then return state end
		req = kres.request_t(req)
		local query = req.qsource.packet
		-- qsource.packet may be nil (presumably for requests with no client packet)
		if query ~= nil then
			-- Best effort: the result of send() is not checked
			sink:send(ffi.string(query.wire, query.size))
		end
		return -- Chain action to next
	end
end

54 55 56 57 58 59 60 61 62 63 64 65 66
-- Forward request, and solve as stub query.
-- @param target  IP address string of the upstream server
-- @return policy action that marks the current query with kres.query.STUB,
--         sets its NS list to target and returns the unchanged state
-- @raise error when target is not a valid IP address
local function forward(target)
	local dst_ip = kres.str2ip(target)
	if dst_ip == nil then
		error("FORWARD target '"..tostring(target).."' is not a valid IP address")
	end
	return function(state, req)
		req = kres.request_t(req)
		local qry = req:current()
		-- NOTE(review): flags are combined with '+', which is only correct if
		-- STUB is not already set on this query - confirm
		qry.flags = qry.flags + kres.query.STUB
		qry:nslist(dst_ip)
		return state
	end
end

67 68 69 70 71 72 73 74 75 76 77 78 79
-- Rewrite records in packet.
-- @param tbl    table of from -> to pairs (presumably address prefixes, or
--               names when `names` is set - see the 'renumber' module)
-- @param names  when truthy, entries are built with renumber.name(),
--               otherwise with renumber.prefix()
-- @return the rule built by renumber.rule() from the collected entries
local function reroute(tbl, names)
	-- Import renumbering rules
	local ren = require('renumber')
	local prefixes = {}
	for from, to in pairs(tbl) do
		-- Explicit branch: `names and ren.name(...) or ren.prefix(...)` would
		-- fall back to ren.prefix() whenever ren.name() returned a falsy value
		if names then
			prefixes[#prefixes + 1] = ren.name(from, to)
		else
			prefixes[#prefixes + 1] = ren.prefix(from, to)
		end
	end
	-- Return rule closure
	return ren.rule(prefixes)
end

80
local policy = {
	-- Built-in policy actions (numeric codes understood by policy.enforce)
	PASS = 1,
	DENY = 2,
	DROP = 3,
	TC = 4,
	-- Policy action constructors
	FORWARD = forward,
	REROUTE = reroute,
	MIRROR = mirror,
	-- Special values
	ANY = 0,
}

87 88 89 90 91
-- Match all requests: the returned rule ignores its arguments and always
-- yields `action`.
function policy.all(action)
	local function match_everything(_req, _query)
		return action
	end
	return match_everything
end

92
-- Match requests whose QNAME ends with one of the zones in `zone_list`
-- (suffix match). The zone list is compiled once, using the 'aho-corasick'
-- module, when the rule is created.
function policy.suffix(action, zone_list)
	local AC = require('aho-corasick')
	local tree = AC.build(zone_list)
	return function(req, query)
		local matches = AC.match(tree, query:name(), false)
		if matches[1] == nil then
			return nil
		end
		return action
	end
end

105
-- Specialized suffix match for zone lists that share a common suffix.
-- A QNAME that does not end with `common_suffix` is rejected with a single
-- comparison before any per-zone check. Comparisons are plain byte-string
-- searches (no Lua patterns). Only the first #suffix_list entries, counted
-- when the rule is created, are consulted.
function policy.suffix_common(action, suffix_list, common_suffix)
	local common_len = string.len(common_suffix)
	local suffix_count = #suffix_list
	return function(req, query)
		local qname = query:name()
		-- Preliminary check against the shared suffix
		if string.find(qname, common_suffix, -common_len, true) == nil then
			return nil
		end
		-- String match against each zone; searching from -len(zone) only
		-- finds the zone when it sits at the very end of qname
		local i = 1
		while i <= suffix_count do
			local zone = suffix_list[i]
			if string.find(qname, zone, -string.len(zone), true) then
				return action
			end
			i = i + 1
		end
		return nil
	end
end

126
-- Match requests whose QNAME (as returned by query:name()) matches the Lua
-- pattern `pattern`.
function policy.pattern(action, pattern)
	return function(req, query)
		local qname = query:name()
		if not string.find(qname, pattern) then
			return nil
		end
		return action
	end
end

136 137 138 139 140 141 142 143 144 145 146 147
-- Parse an RPZ zone file into a table mapping owner name -> policy action.
-- @param action  action used for records whose RDATA is the root name or '*.'
-- @param path    path of the zone file
-- @return table keyed by owner names (raw bytes copied from the parser's owner
--         buffer); records whose RDATA is not in action_map produce no rule
-- @raise error when parser:open(path) fails
local function rpz_parse(action, path)
	local rules = {}
	local ffi = require('ffi')
	-- RDATA bytes -> policy action
	local action_map = {
		-- RPZ Policy Actions
		['\0'] = action,
		['\1*\0'] = action, -- deviates from RPZ spec
		['\012rpz-passthru\0'] = policy.PASS, -- the grammar...
		['\008rpz-drop\0'] = policy.DROP,
		['\012rpz-tcp-only\0'] = policy.TC,
		-- Policy triggers @NYI@
	}
	local parser = require('zonefile').new()
	if not parser:open(path) then error(string.format('failed to parse "%s"', path)) end
	while parser:parse() do
		local name = ffi.string(parser.r_owner, parser.r_owner_length)
		-- Named `rdata`/`rule_action` so the `action` parameter is not shadowed
		local rdata = ffi.string(parser.r_data, parser.r_data_length)
		local rule_action = action_map[rdata]
		rules[name] = rule_action
		-- Warn when NYI; owner names of length <= 1 (the root) are not reported
		if #name > 1 and not rule_action then
			print(string.format('[ rpz ] %s:%d: unsupported policy action', path, tonumber(parser.line_counter)))
		end
	end
	return rules
end

-- Create RPZ from zone file.
-- The returned rule looks up the QNAME exactly first; failing that, it strips
-- one leading label at a time and tries the wildcard owner '*.' + remainder.
local function rpz_zonefile(action, path)
	local rules = rpz_parse(action, path)
	-- Collect garbage left over from parsing (presumably parser temporaries)
	collectgarbage()
	return function(req, query)
		local name = query:name()
		local found = rules[name]
		while found == nil and #name > 0 do
			-- Skip the leftmost label: its length byte plus that many bytes
			name = name:sub(name:byte() + 2)
			found = rules['\1*' .. name]
		end
		return found
	end
end

-- RPZ policy set.
-- Only zone-file input is implemented: format 'lmdb' raises an error, and any
-- other value (including nil) is treated as a zone file at `path`.
function policy.rpz(action, path, format)
	if format ~= 'lmdb' then
		return rpz_zonefile(action, path)
	end
	error('lmdb zone format is NYI')
end

-- Try a single rule. Returns whatever policy.enforce produced for the rule's
-- action, or nil when the rule is suspended or its callback returned nil.
-- A matching rule has its hit counter incremented.
local function evaluate_rule(rule, req, query, state)
	if rule.suspended then
		return nil
	end
	local action = rule.cb(req, query)
	if action == nil then
		return nil
	end
	rule.count = rule.count + 1
	return policy.enforce(state, req, action)
end

-- Evaluate packet in given rules to determine policy action.
-- Rules are tried in list order. The first enforcement that yields a truthy
-- state ends evaluation and that state is returned; chain actions (yielding
-- nil) let evaluation continue. Returns `state` when nothing terminates.
function policy.evaluate(rules, req, query, state)
	for i = 1, #rules do
		local next_state = evaluate_rule(rules[i], req, query, state)
		if next_state then
			return next_state
		end
	end
	return state
end

204
-- Enforce policy action.
-- @param state   current layer state
-- @param req     request (kres.request_t)
-- @param action  policy.DENY, policy.DROP, policy.TC, or function(state, req)
-- @return
--   function -> its return value (nil means "chain to the next rule")
--   DROP     -> kres.FAIL
--   DENY     -> kres.DONE after writing NXDOMAIN with a 'blocked' SOA record
--               (TTL 900) into the AUTHORITY section
--   TC       -> kres.DONE after setting the TC bit, unless answer.max_size
--               is 65535 (then `state`)
--   other    -> `state`
function policy.enforce(state, req, action)
	if type(action) == 'function' then
		return action(state, req)
	end
	if action == policy.DROP then
		return kres.FAIL
	end
	if action == policy.DENY then
		local answer = req.answer
		answer:rcode(kres.rcode.NXDOMAIN)
		answer:begin(kres.section.AUTHORITY)
		answer:put('\7blocked', 900, answer:qclass(), kres.type.SOA,
			'\7blocked\0\0\0\0\0\0\0\0\14\16\0\0\3\132\0\9\58\128\0\0\3\132')
		return kres.DONE
	end
	if action == policy.TC then
		local answer = req.answer
		-- Truncation is only applied to UDP queries (max_size 65535 is
		-- presumably the TCP case)
		if answer.max_size ~= 65535 then
			answer:tc(1)
			return kres.DONE
		end
	end
	return state
end

228 229 230 231 232
-- Resolver layer hooks.
-- Each list is walked top-down and the first rule whose action yields a state
-- wins, so callers must order the lists from most to least specific. Chain
-- actions (yielding nil, e.g. MIRROR) do not stop the walk; the next rule is
-- evaluated.
policy.layer = {}

-- Evaluate policy.rules when a request begins
function policy.layer.begin(state, req)
	req = kres.request_t(req)
	return policy.evaluate(policy.rules, req, req:current(), state)
end

-- Evaluate policy.postrules when a request finishes
function policy.layer.finish(state, req)
	req = kres.request_t(req)
	return policy.evaluate(policy.postrules, req, req:current(), state)
end

244
-- Add rule to policy list.
-- @param rule      rule callback, function(req, query) returning an action or nil
-- @param postrule  when truthy, append to policy.postrules instead of policy.rules
-- @return rule descriptor {id, cb, count}; its id can be passed to policy.del()
function policy.add(rule, postrule)
	-- Compatibility with 1.0.0 API (called as policy:add(rule))
	-- it will be dropped in 1.2.0
	if rule == policy then
		rule, postrule = postrule, nil
	end
	-- End of compatibility shim
	local desc = {id=getruleid(), cb=rule, count=0}
	local target_list = postrule and policy.postrules or policy.rules
	target_list[#target_list + 1] = desc
	return desc
end
257

258 259 260 261 262 263 264 265 266 267 268 269
-- Remove from `rules` (in place) the first descriptor whose id equals `id`.
-- Scanning stops at the first nil entry.
-- @return true if a rule was removed, false otherwise
local function delrule(rules, id)
	local i = 1
	local entry = rules[1]
	while entry ~= nil do
		if entry.id == id then
			table.remove(rules, i)
			return true
		end
		i = i + 1
		entry = rules[i]
	end
	return false
end

-- Delete rule from policy list.
-- policy.rules is searched first, then policy.postrules.
-- @param id  id from a descriptor returned by policy.add()
-- @return true if a rule was removed, false otherwise
function policy.del(id)
	return delrule(policy.rules, id) or delrule(policy.postrules, id)
end

279
-- Convert a list of string names to domain names, in place.
-- Every entry is replaced with kres.str2dname(entry); the same table is
-- returned for convenience.
function policy.todnames(names)
	for index, text_name in ipairs(names) do
		local dname = kres.str2dname(text_name)
		names[index] = dname
	end
	return names
end

-- Private, local, broadcast, test and special-use reverse zones.
-- All entries are absolute names (with a trailing dot) and are converted to
-- wire format with policy.todnames() before use.
local private_zones = {
	-- RFC1918
	'10.in-addr.arpa.',
	'16.172.in-addr.arpa.',
	'17.172.in-addr.arpa.',
	'18.172.in-addr.arpa.',
	'19.172.in-addr.arpa.',
	'20.172.in-addr.arpa.',
	'21.172.in-addr.arpa.',
	'22.172.in-addr.arpa.',
	'23.172.in-addr.arpa.',
	'24.172.in-addr.arpa.',
	'25.172.in-addr.arpa.',
	'26.172.in-addr.arpa.',
	'27.172.in-addr.arpa.',
	'28.172.in-addr.arpa.',
	'29.172.in-addr.arpa.',
	'30.172.in-addr.arpa.',
	'31.172.in-addr.arpa.',
	'168.192.in-addr.arpa.',
	-- RFC5735, RFC5737
	'0.in-addr.arpa.',
	'127.in-addr.arpa.',
	'254.169.in-addr.arpa.',
	'2.0.192.in-addr.arpa.',
	'100.51.198.in-addr.arpa.',
	'113.0.203.in-addr.arpa.',
	'255.255.255.255.in-addr.arpa.',
	-- IPv6 local, example
	'0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.',
	'1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.',
	'd.f.ip6.arpa.',
	'8.e.f.ip6.arpa.',
	'9.e.f.ip6.arpa.',
	'a.e.f.ip6.arpa.',
	'b.e.f.ip6.arpa.',
	-- trailing dot added for consistency with the other absolute names
	'8.b.d.0.1.0.0.2.ip6.arpa.',
}
325
-- Convert the private zone list to wire format in place before it is used
policy.todnames(private_zones)

-- @var Default rules
policy.rules, policy.postrules = {}, {}
-- Deny private/special-use reverse zones by default; every entry ends in
-- '\4arpa\0', which suffix_common uses as its cheap preliminary check
policy.add(policy.suffix_common(policy.DENY, private_zones, '\4arpa\0'))
331

332
return policy