summaryrefslogtreecommitdiff
path: root/python/vyos/firewall.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/vyos/firewall.py')
-rw-r--r--python/vyos/firewall.py138
1 files changed, 138 insertions, 0 deletions
diff --git a/python/vyos/firewall.py b/python/vyos/firewall.py
index 355ec44b0..a61d0a9f8 100644
--- a/python/vyos/firewall.py
+++ b/python/vyos/firewall.py
@@ -14,11 +14,22 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
+import csv
+import gzip
+import os
import re
+from pathlib import Path
+from time import strftime
+
+from vyos.remote import download
+from vyos.template import is_ipv4
+from vyos.template import render
from vyos.util import call
from vyos.util import cmd
from vyos.util import dict_search_args
+from vyos.util import dict_search_recursive
+from vyos.util import run
# Functions for firewall group domain-groups
@@ -139,6 +150,9 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name):
if suffix[0] == '!':
suffix = f'!= {suffix[1:]}'
output.append(f'{ip_name} {prefix}addr {suffix}')
+
+ if dict_search_args(side_conf, 'geoip', 'country_code'):
+ output.append(f'{ip_name} {prefix}addr @GEOIP_CC_{fw_name}_{rule_id}')
if 'mac_address' in side_conf:
suffix = side_conf["mac_address"]
@@ -338,3 +352,127 @@ def parse_policy_set(set_conf, def_suffix):
mss = set_conf['tcp_mss']
out.append(f'tcp option maxseg size set {mss}')
return " ".join(out)
+
+# GeoIP
+
+nftables_geoip_conf = '/run/nftables-geoip.conf'
+geoip_database = '/usr/share/vyos-geoip/dbip-country-lite.csv.gz'
+geoip_lock_file = '/run/vyos-geoip.lock'
+
+def geoip_load_data(codes=[]):
+ data = None
+
+ if not os.path.exists(geoip_database):
+ return []
+
+ try:
+ with gzip.open(geoip_database, mode='rt') as csv_fh:
+ reader = csv.reader(csv_fh)
+ out = []
+ for start, end, code in reader:
+ if code.lower() in codes:
+ out.append([start, end, code.lower()])
+ return out
+ except:
+ print('Error: Failed to open GeoIP database')
+ return []
+
+def geoip_download_data():
+ url = 'https://download.db-ip.com/free/dbip-country-lite-{}.csv.gz'.format(strftime("%Y-%m"))
+ try:
+ dirname = os.path.dirname(geoip_database)
+ if not os.path.exists(dirname):
+ os.mkdir(dirname)
+
+ download(geoip_database, url)
+ print("Downloaded GeoIP database")
+ return True
+ except:
+ print("Error: Failed to download GeoIP database")
+ return False
+
+class GeoIPLock(object):
+ def __init__(self, file):
+ self.file = file
+
+ def __enter__(self):
+ if os.path.exists(self.file):
+ return False
+
+ Path(self.file).touch()
+ return True
+
+ def __exit__(self, exc_type, exc_value, tb):
+ os.unlink(self.file)
+
+def geoip_update(firewall, force=False):
+ with GeoIPLock(geoip_lock_file) as lock:
+ if not lock:
+ print("Script is already running")
+ return False
+
+ if not firewall:
+ print("Firewall is not configured")
+ return True
+
+ if not os.path.exists(geoip_database):
+ if not geoip_download_data():
+ return False
+ elif force:
+ geoip_download_data()
+
+ ipv4_codes = {}
+ ipv6_codes = {}
+
+ ipv4_sets = {}
+ ipv6_sets = {}
+
+ # Map country codes to set names
+ for codes, path in dict_search_recursive(firewall, 'country_code'):
+ if path[0] == 'name':
+ set_name = f'GEOIP_CC_{path[1]}_{path[3]}'
+ ipv4_sets[set_name] = []
+ for code in codes:
+ if code not in ipv4_codes:
+ ipv4_codes[code] = [set_name]
+ else:
+ ipv4_codes[code].append(set_n)
+ elif path[0] == 'ipv6_name':
+ set_name = f'GEOIP_CC_{path[1]}_{path[3]}'
+ ipv6_sets[set_name] = []
+ for code in codes:
+ if code not in ipv6_codes:
+ ipv6_codes[code] = [set_name]
+ else:
+ ipv6_codes[code].append(set_name)
+
+ if not ipv4_codes and not ipv6_codes:
+ if force:
+ print("GeoIP not in use by firewall")
+ return True
+
+ geoip_data = geoip_load_data([*ipv4_codes, *ipv6_codes])
+
+ # Iterate IP blocks to assign to sets
+ for start, end, code in geoip_data:
+ ipv4 = is_ipv4(start)
+ if code in ipv4_codes and ipv4:
+ ip_range = f'{start}-{end}' if start != end else start
+ for setname in ipv4_codes[code]:
+ ipv4_sets[setname].append(ip_range)
+ if code in ipv6_codes and not ipv4:
+ ip_range = f'{start}-{end}' if start != end else start
+ for setname in ipv6_codes[code]:
+ ipv6_sets[setname].append(ip_range)
+
+ render(nftables_geoip_conf, 'firewall/nftables-geoip-update.j2', {
+ 'ipv4_sets': ipv4_sets,
+ 'ipv6_sets': ipv6_sets
+ })
+
+ result = run(f'nft -f {nftables_geoip_conf}')
+ if result != 0:
+ print('Error: GeoIP failed to update firewall')
+ return False
+
+ return True