diff --git a/main.py b/main.py index 06daeba..33c1fdf 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,18 @@ import sqlite3 import os from flask import Flask, request, jsonify, render_template, Response +from flask_httpauth import HTTPBasicAuth +import configparser + +PREFIX='/data' +if os.environ.get('PREFIX') is not None: + PREFIX=os.environ.get('PREFIX') + app = Flask(__name__) +config = configparser.ConfigParser() +auth = HTTPBasicAuth() + +config.read(os.path.join(PREFIX, 'vpnunit.config.ini')) class Ex(Exception): def __init__(self, code, message): @@ -12,7 +23,8 @@ class Ex(Exception): def __str__(self): return self._message -DATABASE='/data/database.sqlite3' + +DATABASE=os.path.join(PREFIX, 'database.sqlite3') db = sqlite3.connect(DATABASE) cu = db.cursor() db.execute('SELECT name FROM sqlite_master') @@ -24,15 +36,21 @@ if len(cu.fetchall()) == 0: cu.close() db.close() -os.environ['EASYRSA_PKI'] = '/data/pki' +os.environ['EASYRSA_PKI'] = os.path.join(PREFIX, 'pki') os.environ['EASYRSA_BATCH'] = '1' os.environ['PATH'] = os.environ['PATH'] + ':' + os.getcwd() + '/easy-rsa/easyrsa3' +@auth.verify_password +def verify_user(username, password): + if username == config['DEFAULT']['username'] and password == config['DEFAULT']['password']: + return username + @app.route('/') def hello_world(): return 'It works!' @app.route('/users', methods=['GET']) +@auth.login_required def get_users(): db = sqlite3.connect(DATABASE) cu = db.cursor() @@ -46,6 +64,7 @@ def get_users(): return jsonify(users) @app.route('/users', methods=['POST']) +@auth.login_required def post_users(): db = sqlite3.connect(DATABASE) cu = db.cursor() @@ -64,6 +83,7 @@ def post_users(): db.close() @app.route('/gateways', methods=['GET']) +@auth.login_required def get_gateways(): db = sqlite3.connect(DATABASE) cu = db.cursor() @@ -77,6 +97,7 @@ def get_gateways(): return jsonify(gateways) @app.route('/gateways', methods=['POST']) +@auth.login_required def post_gateways(): db = sqlite3.connect(DATABASE) cu = db.cursor() @@ -117,10 +138,10 @@ def post_gateways(): network = '2001:470:c844:' + ipid + '0::/60' staticclient = render_template('staticclient', address=address, network=network) - with open('/data/ovpn/clients/' + name, 'w') as f: + with open(os.path.join(PREFIX, 'ovpn/clients/', name), 'w') as f: f.write(staticclient) - with open('/data/ip/routes', 'a') as f: + with open(os.path.join(PREFIX, 'ip/routes'), 'a') as f: f.write(network + ' via ' + address + '\n') db.commit() @@ -136,6 +157,7 @@ def post_gateways(): db.close() @app.route('/gateway/', methods=['GET']) +@auth.login_required def get_gateway(fqdn): db = sqlite3.connect(DATABASE) cu = db.cursor() @@ -153,6 +175,7 @@ def get_gateway(fqdn): return jsonify(gateway) @app.route('/gateway//config', methods=['GET']) +@auth.login_required def get_gateway_config(fqdn): # TODO sanity check FQDN # WARNING: maybe you want to do more than a simple sanity check, @@ -169,6 +192,7 @@ def get_gateway_config(fqdn): return Response(render_template('config.ovpn', ca=ca, cert=cert, key=key), mimetype='text/plain') @app.route('/gateway/', methods=['DELETE']) +@auth.login_required def delete_gateway(fqdn): # TODO sanity check for this parameter! Possible system command injection db = sqlite3.connect(DATABASE) @@ -182,7 +206,7 @@ def delete_gateway(fqdn): address = '2001:470:c844::' + ipid + '0/64' network = '2001:470:c844:' + ipid + '0::/60' - sedrm = "sed -i '\\_^" + network + " via " + address + "$_d' /data/ip/routes" + sedrm = "sed -i '\\_^" + network + " via " + address + "$_d' " + PREFIX + "/ip/routes" print('[sedrm] ' + sedrm) r = os.system(sedrm) if r != 0: @@ -200,7 +224,7 @@ def delete_gateway(fqdn): except Ex as e: return jsonify({'status': 'error', 'message': str(e)}), e.getCode() - os.remove('/data/ovpn/clients/' + fqdn) + os.remove(os.path.join(PREFIX, 'ovpn/clients', fqdn)) db.commit() cu.close()