# ss / settings.py
# (uploaded by wnum in commit ddfa897, "Upload 2 files")
import logging
import os.path
import ssl
import sys
from tornado.options import define
from webssh.policy import (
load_host_keys, get_policy_class, check_policy_setting
)
from webssh.utils import (
to_ip_address, parse_origin_from_url, is_valid_encoding
)
from webssh._version import __version__
def print_version(flag):
    """Callback for the ``--version`` option: print the version and exit.

    :param flag: value of the ``version`` option; nothing happens when falsy.
    """
    if not flag:
        return
    print(__version__)
    sys.exit(0)
# ---------------------------------------------------------------------------
# Command-line / config options.
#
# Each define() call below runs at import time, registering an option with
# tornado.options (a module-level side effect). The get_*_settings() helpers
# further down read the parsed values back from the ``options`` object they
# are handed.
# ---------------------------------------------------------------------------

# Plain HTTP listener.
define('address', default='', help='Listen address')
define('port', type=int, default=8888, help='Listen port')
# HTTPS listener; certfile/keyfile are validated by get_ssl_context().
define('ssladdress', default='', help='SSL listen address')
define('sslport', type=int, default=4433, help='SSL listen port')
define('certfile', default='', help='SSL certificate file')
define('keyfile', default='', help='SSL private key file')
# Passed to the Tornado application settings; also the only mode in which
# the '*' origin policy is accepted (see get_origin_setting()).
define('debug', type=bool, default=False, help='Debug mode')
# Name resolved to a policy class in get_policy_setting().
define('policy', default='warning',
       help='Missing host key policy, reject|autoadd|warning')
# Empty values fall back to the defaults chosen in get_host_keys_settings():
# <base_dir>/known_hosts and ~/.ssh/known_hosts respectively.
define('hostfile', default='', help='User defined host keys file')
define('syshostfile', default='', help='System wide host keys file')
# Comma-separated addresses, parsed by get_trusted_downstream().
define('tdstream', default='', help='Trusted downstream, separated by comma')
# NOTE(review): redirect/fbidhttp are not read anywhere in this module;
# presumably consumed by the request handlers / server setup — confirm.
define('redirect', type=bool, default=True, help='Redirecting http to https')
define('fbidhttp', type=bool, default=True,
       help='Forbid public plain http incoming requests')
# xheaders feeds get_server_settings(); xsrf feeds get_app_settings().
define('xheaders', type=bool, default=True, help='Support xheaders')
define('xsrf', type=bool, default=True, help='CSRF protection')
# Parsed by get_origin_setting(). The help text is a literal multi-line
# string, so its continuation lines deliberately start at column 0.
define('origin', default='same', help='''Origin policy,
'same': same origin policy, matches host name and port number;
'primary': primary domain policy, matches primary domain only;
'<domains>': custom domains policy, matches any domain in the <domains> list
separated by comma;
'*': wildcard policy, matches any domain, allowed in debug mode only.''')
# Passed to Tornado as websocket_ping_interval (see get_app_settings()).
define('wpintvl', type=float, default=0, help='Websocket ping interval')
# NOTE(review): timeout/delay units are not visible here; presumably
# seconds — confirm where they are consumed.
define('timeout', type=float, default=3, help='SSH connection timeout')
define('delay', type=float, default=3, help='The delay to call recycle_worker')
define('maxconn', type=int, default=20,
       help='Maximum live connections (ssh sessions) per client')
# Filename inside webssh/static/css/fonts; resolved by get_font_filename().
define('font', default='', help='custom font filename')
# Presumably validated via check_encoding_setting() — confirm at call site.
define('encoding', default='utf-8',
       help='''The default character encoding of ssh servers.
Example: --encoding='utf-8' to solve the problem with some switches&routers''')
# print_version() prints the version and exits when this flag is set.
define('version', type=bool, help='Show version information',
       callback=print_version)
# Root directory: two levels above this file's absolute path.
base_dir = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__))
)
# Path segments from base_dir down to the bundled font directory.
font_dirs = 'webssh/static/css/fonts'.split('/')
# Maximum request body size handed to the HTTP server (1 MiB).
max_body_size = 1024 ** 2
class Font(object):
    """A font exposed to the templates as a CSS family name plus a URL.

    :param filename: font file name, e.g. ``'Hack.ttf'``.
    :param dirs: list of path segments leading to the font's directory;
        they are joined with ``'/'`` together with *filename*.
    """

    def __init__(self, filename, dirs):
        self.family = self.get_family(filename)
        self.url = self.get_url(filename, dirs)

    def get_family(self, filename):
        """Return everything in *filename* before its first ``'.'``."""
        family, _sep, _rest = filename.partition('.')
        return family

    def get_url(self, filename, dirs):
        """Return the ``'/'``-joined path of *dirs* followed by *filename*."""
        segments = dirs + [filename]
        return '/'.join(segments)
def get_app_settings(options):
    """Build the keyword settings for the Tornado application.

    :param options: parsed options; reads ``wpintvl``, ``debug``, ``xsrf``
        and ``font`` here, and ``origin``/``debug`` via get_origin_setting().
    :returns: dict of application settings.
    :raises ValueError: propagated from get_font_filename() (checked first)
        or get_origin_setting().
    """
    webssh_dir = os.path.join(base_dir, 'webssh')
    font_filename = get_font_filename(
        options.font, os.path.join(base_dir, *font_dirs)
    )
    return {
        'template_path': os.path.join(webssh_dir, 'templates'),
        'static_path': os.path.join(webssh_dir, 'static'),
        'websocket_ping_interval': options.wpintvl,
        'debug': options.debug,
        'xsrf_cookies': options.xsrf,
        # The URL drops the leading 'webssh' segment of font_dirs.
        'font': Font(font_filename, font_dirs[1:]),
        'origin_policy': get_origin_setting(options),
    }
def get_server_settings(options):
    """Build the keyword settings for the Tornado HTTP server.

    :param options: parsed options; reads ``xheaders`` and ``tdstream``.
    :returns: dict with ``xheaders``, ``max_body_size`` and
        ``trusted_downstream``.
    """
    return {
        'xheaders': options.xheaders,
        'max_body_size': max_body_size,
        'trusted_downstream': get_trusted_downstream(options.tdstream),
    }
def get_host_keys_settings(options):
    """Load the user-defined and system-wide known-hosts files.

    Empty ``hostfile`` falls back to ``<base_dir>/known_hosts``; empty
    ``syshostfile`` falls back to ``~/.ssh/known_hosts``. The user file is
    loaded before the system file.

    :returns: dict with ``host_keys``, ``system_host_keys`` and
        ``host_keys_filename`` (the path the user host keys came from).
    """
    host_keys_filename = (options.hostfile
                          or os.path.join(base_dir, 'known_hosts'))
    host_keys = load_host_keys(host_keys_filename)

    system_filename = (options.syshostfile
                       or os.path.expanduser('~/.ssh/known_hosts'))
    system_host_keys = load_host_keys(system_filename)

    return {
        'host_keys': host_keys,
        'system_host_keys': system_host_keys,
        'host_keys_filename': host_keys_filename,
    }
def get_policy_setting(options, host_keys_settings):
    """Instantiate the missing-host-key policy selected by ``options.policy``.

    The chosen class name is logged at INFO level, and the class is checked
    against *host_keys_settings* before being instantiated with no arguments.
    """
    chosen = get_policy_class(options.policy)
    logging.info(chosen.__name__)
    check_policy_setting(chosen, host_keys_settings)
    policy = chosen()
    return policy
def get_ssl_context(options):
    """Build a server-side SSL context from ``certfile``/``keyfile``.

    :returns: ``None`` when neither file is configured, otherwise an
        ``ssl.SSLContext`` with the certificate chain loaded.
    :raises ValueError: if only one of the two is configured, or if a
        configured path is not an existing regular file.
    """
    certfile, keyfile = options.certfile, options.keyfile
    if not (certfile or keyfile):
        return None
    if not certfile:
        raise ValueError('certfile is not provided')
    if not keyfile:
        raise ValueError('keyfile is not provided')
    for path in (certfile, keyfile):
        if not os.path.isfile(path):
            raise ValueError('File {!r} does not exist'.format(path))

    # CLIENT_AUTH purpose: the context authenticates clients, i.e. it is
    # used on the server side.
    context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    context.load_cert_chain(certfile, keyfile)
    return context
def get_trusted_downstream(tdstream):
    """Parse the comma-separated ``--tdstream`` value into a set of addresses.

    Surrounding whitespace is stripped and blank entries are skipped. Every
    remaining entry is passed to ``to_ip_address`` before being added.

    :param tdstream: comma-separated string of addresses.
    :returns: set of the stripped address strings (not converted objects).
    """
    entries = (entry.strip() for entry in tdstream.split(','))
    trusted = set()
    for ip in entries:
        if not ip:
            continue
        # Return value is discarded: the call serves only as a check
        # (presumably raising on an invalid address — see webssh.utils).
        to_ip_address(ip)
        trusted.add(ip)
    return trusted
def get_origin_setting(options):
    """Resolve the ``--origin`` option into an origin policy.

    :returns: ``'*'`` (debug mode only), ``'same'`` or ``'primary'``
        (case-insensitive), or a non-empty set of origins parsed from a
        comma-separated list.
    :raises ValueError: for ``'*'`` outside debug mode, or when no entry of a
        custom list yields an origin.
    """
    if options.origin == '*':
        if options.debug:
            return '*'
        raise ValueError(
            'Wildcard origin policy is only allowed in debug mode.'
        )

    policy = options.origin.lower()
    if policy in ('same', 'primary'):
        return policy

    # Entries for which parse_origin_from_url returns a falsy value are dropped.
    parsed = map(parse_origin_from_url, policy.split(','))
    origins = {orig for orig in parsed if orig}
    if not origins:
        raise ValueError('Empty origin list')
    return origins
def get_font_filename(font, font_dir):
    """Resolve which font file in *font_dir* should be served.

    Candidates are the regular files directly inside *font_dir* whose names
    do not start with ``'.'``.

    :param font: requested font filename; falsy to pick a default.
    :param font_dir: directory to look in.
    :returns: *font* if it was given and is a candidate; otherwise the
        lexicographically smallest candidate, or *font* unchanged when there
        are no candidates.
    :raises ValueError: if *font* is given but is not among the candidates.
    """
    filenames = {
        name for name in os.listdir(font_dir)
        if not name.startswith('.')
        and os.path.isfile(os.path.join(font_dir, name))
    }
    if font:
        if font not in filenames:
            raise ValueError(
                'Font file {!r} not found'.format(os.path.join(font_dir, font))
            )
    elif filenames:
        # Choose deterministically. set.pop() returns an arbitrary element
        # whose identity changes with string hash randomization, so the
        # default font could differ between server restarts.
        font = min(filenames)
    return font
def check_encoding_setting(encoding):
    """Validate a character-encoding name.

    A falsy *encoding* counts as "not set" and is accepted.

    :raises ValueError: if *encoding* is set but ``is_valid_encoding``
        rejects it.
    """
    if not encoding:
        return
    if not is_valid_encoding(encoding):
        raise ValueError('Unknown character encoding {!r}.'.format(encoding))