import logging import os.path import ssl import sys from tornado.options import define from webssh.policy import ( load_host_keys, get_policy_class, check_policy_setting ) from webssh.utils import ( to_ip_address, parse_origin_from_url, is_valid_encoding ) from webssh._version import __version__ def print_version(flag): if flag: print(__version__) sys.exit(0) define('address', default='', help='Listen address') define('port', type=int, default=8888, help='Listen port') define('ssladdress', default='', help='SSL listen address') define('sslport', type=int, default=4433, help='SSL listen port') define('certfile', default='', help='SSL certificate file') define('keyfile', default='', help='SSL private key file') define('debug', type=bool, default=False, help='Debug mode') define('policy', default='warning', help='Missing host key policy, reject|autoadd|warning') define('hostfile', default='', help='User defined host keys file') define('syshostfile', default='', help='System wide host keys file') define('tdstream', default='', help='Trusted downstream, separated by comma') define('redirect', type=bool, default=True, help='Redirecting http to https') define('fbidhttp', type=bool, default=True, help='Forbid public plain http incoming requests') define('xheaders', type=bool, default=True, help='Support xheaders') define('xsrf', type=bool, default=True, help='CSRF protection') define('origin', default='same', help='''Origin policy, 'same': same origin policy, matches host name and port number; 'primary': primary domain policy, matches primary domain only; '': custom domains policy, matches any domain in the list separated by comma; '*': wildcard policy, matches any domain, allowed in debug mode only.''') define('wpintvl', type=float, default=0, help='Websocket ping interval') define('timeout', type=float, default=3, help='SSH connection timeout') define('delay', type=float, default=3, help='The delay to call recycle_worker') define('maxconn', type=int, default=20, help='Maximum live connections (ssh sessions) per client') define('font', default='', help='custom font filename') define('encoding', default='utf-8', help='''The default character encoding of ssh servers. Example: --encoding='utf-8' to solve the problem with some switches&routers''') define('version', type=bool, help='Show version information', callback=print_version) base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) font_dirs = ['webssh', 'static', 'css', 'fonts'] max_body_size = 1 * 1024 * 1024 class Font(object): def __init__(self, filename, dirs): self.family = self.get_family(filename) self.url = self.get_url(filename, dirs) def get_family(self, filename): return filename.split('.')[0] def get_url(self, filename, dirs): return '/'.join(dirs + [filename]) def get_app_settings(options): settings = dict( template_path=os.path.join(base_dir, 'webssh', 'templates'), static_path=os.path.join(base_dir, 'webssh', 'static'), websocket_ping_interval=options.wpintvl, debug=options.debug, xsrf_cookies=options.xsrf, font=Font( get_font_filename(options.font, os.path.join(base_dir, *font_dirs)), font_dirs[1:] ), origin_policy=get_origin_setting(options) ) return settings def get_server_settings(options): settings = dict( xheaders=options.xheaders, max_body_size=max_body_size, trusted_downstream=get_trusted_downstream(options.tdstream) ) return settings def get_host_keys_settings(options): if not options.hostfile: host_keys_filename = os.path.join(base_dir, 'known_hosts') else: host_keys_filename = options.hostfile host_keys = load_host_keys(host_keys_filename) if not options.syshostfile: filename = os.path.expanduser('~/.ssh/known_hosts') else: filename = options.syshostfile system_host_keys = load_host_keys(filename) settings = dict( host_keys=host_keys, system_host_keys=system_host_keys, host_keys_filename=host_keys_filename ) return settings def get_policy_setting(options, host_keys_settings): policy_class = get_policy_class(options.policy) logging.info(policy_class.__name__) check_policy_setting(policy_class, host_keys_settings) return policy_class() def get_ssl_context(options): if not options.certfile and not options.keyfile: return None elif not options.certfile: raise ValueError('certfile is not provided') elif not options.keyfile: raise ValueError('keyfile is not provided') elif not os.path.isfile(options.certfile): raise ValueError('File {!r} does not exist'.format(options.certfile)) elif not os.path.isfile(options.keyfile): raise ValueError('File {!r} does not exist'.format(options.keyfile)) else: ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) ssl_ctx.load_cert_chain(options.certfile, options.keyfile) return ssl_ctx def get_trusted_downstream(tdstream): result = set() for ip in tdstream.split(','): ip = ip.strip() if ip: to_ip_address(ip) result.add(ip) return result def get_origin_setting(options): if options.origin == '*': if not options.debug: raise ValueError( 'Wildcard origin policy is only allowed in debug mode.' ) else: return '*' origin = options.origin.lower() if origin in ['same', 'primary']: return origin origins = set() for url in origin.split(','): orig = parse_origin_from_url(url) if orig: origins.add(orig) if not origins: raise ValueError('Empty origin list') return origins def get_font_filename(font, font_dir): filenames = {f for f in os.listdir(font_dir) if not f.startswith('.') and os.path.isfile(os.path.join(font_dir, f))} if font: if font not in filenames: raise ValueError( 'Font file {!r} not found'.format(os.path.join(font_dir, font)) ) elif filenames: font = filenames.pop() return font def check_encoding_setting(encoding): if encoding and not is_valid_encoding(encoding): raise ValueError('Unknown character encoding {!r}.'.format(encoding))