Fixed test_print_version

pull/30/head
Sheng 2018-10-03 19:22:20 +08:00
parent fdcf1718c1
commit 90e7ea0327
1 changed files with 13 additions and 2 deletions

View File

@ -1,3 +1,5 @@
import io
import sys
import os.path import os.path
import unittest import unittest
import paramiko import paramiko
@ -8,15 +10,24 @@ from webssh.policy import load_host_keys
from webssh.settings import ( from webssh.settings import (
get_host_keys_settings, get_policy_setting, base_dir, print_version get_host_keys_settings, get_policy_setting, base_dir, print_version
) )
from webssh.utils import UnicodeType
from webssh._version import __version__ from webssh._version import __version__
class TestSettings(unittest.TestCase): class TestSettings(unittest.TestCase):
def test_print_version(self): def test_print_version(self):
self.assertNotEqual(print_version(False), 2, msg=__version__) sys_stdout = sys.stdout
sys.stdout = io.StringIO() if UnicodeType == str else io.BytesIO()
self.assertEqual(print_version(False), None)
self.assertEqual(sys.stdout.getvalue(), '')
with self.assertRaises(SystemExit): with self.assertRaises(SystemExit):
self.assertEqual(print_version(True), 2, msg=__version__) self.assertEqual(print_version(True), None)
self.assertEqual(sys.stdout.getvalue(), __version__ + '\n')
sys.stdout = sys_stdout
def test_get_host_keys_settings(self): def test_get_host_keys_settings(self):
options.hostFile = '' options.hostFile = ''