mirror of https://github.com/huashengdun/webssh
Added is_valid_hostname to utils
parent
f610020758
commit
37299468a9
|
@ -70,6 +70,14 @@ class TestApp(AsyncHTTPTestCase):
|
||||||
response = self.fetch('/', method='POST', body=body)
|
response = self.fetch('/', method='POST', body=body)
|
||||||
self.assertIn(b'"status": "The port field is required"', response.body)
|
self.assertIn(b'"status": "The port field is required"', response.body)
|
||||||
|
|
||||||
|
body = 'hostname=127.0.0&port=22&username=&password'
|
||||||
|
response = self.fetch('/', method='POST', body=body)
|
||||||
|
self.assertIn(b'"status": "Invalid hostname', response.body)
|
||||||
|
|
||||||
|
body = 'hostname=http://www.googe.com&port=22&username=&password'
|
||||||
|
response = self.fetch('/', method='POST', body=body)
|
||||||
|
self.assertIn(b'"status": "Invalid hostname', response.body)
|
||||||
|
|
||||||
body = 'hostname=127.0.0.1&port=port&username=&password'
|
body = 'hostname=127.0.0.1&port=port&username=&password'
|
||||||
response = self.fetch('/', method='POST', body=body)
|
response = self.fetch('/', method='POST', body=body)
|
||||||
self.assertIn(b'"status": "Invalid port', response.body)
|
self.assertIn(b'"status": "Invalid port', response.body)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from webssh.utils import (is_valid_ipv4_address, is_valid_ipv6_address,
|
from webssh.utils import (is_valid_ipv4_address, is_valid_ipv6_address,
|
||||||
is_valid_port, to_str, to_bytes)
|
is_valid_port, is_valid_hostname, to_str, to_bytes)
|
||||||
|
|
||||||
|
|
||||||
class TestUitls(unittest.TestCase):
|
class TestUitls(unittest.TestCase):
|
||||||
|
@ -34,3 +34,14 @@ class TestUitls(unittest.TestCase):
|
||||||
self.assertTrue(is_valid_port(80))
|
self.assertTrue(is_valid_port(80))
|
||||||
self.assertFalse(is_valid_port(0))
|
self.assertFalse(is_valid_port(0))
|
||||||
self.assertFalse(is_valid_port(65536))
|
self.assertFalse(is_valid_port(65536))
|
||||||
|
|
||||||
|
def test_is_valid_hostname(self):
|
||||||
|
self.assertTrue(is_valid_hostname('google.com'))
|
||||||
|
self.assertTrue(is_valid_hostname('google.com.'))
|
||||||
|
self.assertTrue(is_valid_hostname('www.google.com'))
|
||||||
|
self.assertTrue(is_valid_hostname('www.google.com.'))
|
||||||
|
self.assertFalse(is_valid_hostname('.www.google.com'))
|
||||||
|
self.assertFalse(is_valid_hostname('http://www.google.com'))
|
||||||
|
self.assertFalse(is_valid_hostname('https://www.google.com'))
|
||||||
|
self.assertFalse(is_valid_hostname('127.0.0.1'))
|
||||||
|
self.assertFalse(is_valid_hostname('::1'))
|
||||||
|
|
|
@ -11,8 +11,10 @@ import tornado.web
|
||||||
|
|
||||||
from tornado.ioloop import IOLoop
|
from tornado.ioloop import IOLoop
|
||||||
from webssh.worker import Worker, recycle_worker, workers
|
from webssh.worker import Worker, recycle_worker, workers
|
||||||
from webssh.utils import (is_valid_ipv4_address, is_valid_ipv6_address,
|
from webssh.utils import (
|
||||||
is_valid_port, to_bytes, to_str, UnicodeType)
|
is_valid_ipv4_address, is_valid_ipv6_address, is_valid_port,
|
||||||
|
is_valid_hostname, to_bytes, to_str, UnicodeType
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
|
@ -98,6 +100,13 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
||||||
'wrong password for decrypting the private key.')
|
'wrong password for decrypting the private key.')
|
||||||
return pkey
|
return pkey
|
||||||
|
|
||||||
|
def get_hostname(self):
|
||||||
|
value = self.get_value('hostname')
|
||||||
|
if not (is_valid_hostname(value) | is_valid_ipv4_address(value) |
|
||||||
|
is_valid_ipv6_address(value)):
|
||||||
|
raise ValueError('Invalid hostname {}'.format(value))
|
||||||
|
return value
|
||||||
|
|
||||||
def get_port(self):
|
def get_port(self):
|
||||||
value = self.get_value('port')
|
value = self.get_value('port')
|
||||||
try:
|
try:
|
||||||
|
@ -117,7 +126,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def get_args(self):
|
def get_args(self):
|
||||||
hostname = self.get_value('hostname')
|
hostname = self.get_hostname()
|
||||||
port = self.get_port()
|
port = self.get_port()
|
||||||
username = self.get_value('username')
|
username = self.get_value('username')
|
||||||
password = self.get_argument('password')
|
password = self.get_argument('password')
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
import ipaddress
|
import ipaddress
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from types import UnicodeType
|
from types import UnicodeType
|
||||||
|
@ -38,3 +40,20 @@ def is_valid_ipv6_address(ipstr):
|
||||||
|
|
||||||
def is_valid_port(port):
|
def is_valid_port(port):
|
||||||
return 0 < port < 65536
|
return 0 < port < 65536
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_hostname(hostname):
|
||||||
|
if hostname[-1] == ".":
|
||||||
|
# strip exactly one dot from the right, if present
|
||||||
|
hostname = hostname[:-1]
|
||||||
|
if len(hostname) > 253:
|
||||||
|
return False
|
||||||
|
|
||||||
|
labels = hostname.split(".")
|
||||||
|
|
||||||
|
# the TLD must be not all-numeric
|
||||||
|
if re.match(r"[0-9]+$", labels[-1]):
|
||||||
|
return False
|
||||||
|
|
||||||
|
allowed = re.compile(r"(?!-)[a-z0-9-]{1,63}(?<!-)$", re.IGNORECASE)
|
||||||
|
return all(allowed.match(label) for label in labels)
|
||||||
|
|
Loading…
Reference in New Issue