From b31f8d58678eb170eb7d463ceafbb76c57a2d641 Mon Sep 17 00:00:00 2001
From: ibuler <ibuler@qq.com>
Date: Wed, 31 Aug 2016 19:28:06 +0800
Subject: [PATCH] Add user generate reset password token

---
 apps/common/tasks.py            | 34 +++++++++++++++++++++++++
 apps/jumpserver/settings.py     | 10 ++++++++
 apps/users/models.py            | 45 ++++++++++++++++++++++++---------
 apps/users/tests/test_models.py |  8 ++++++
 apps/users/tests/test_views.py  |  8 ++++++
 apps/users/utils.py             |  1 +
 6 files changed, 94 insertions(+), 12 deletions(-)
 create mode 100644 apps/common/tasks.py

diff --git a/apps/common/tasks.py b/apps/common/tasks.py
new file mode 100644
index 000000000..8df504baa
--- /dev/null
+++ b/apps/common/tasks.py
@@ -0,0 +1,34 @@
+from __future__ import absolute_import
+
+from celery import shared_task
+from django.core.mail import send_mail
+from django.conf import settings
+
+
+@shared_task(name='send_mail_async')
+def send_mail_async(*args, **kwargs):
+    """ Using celery to send email async
+
+    You can use it as django send_mail function
+
+    Example:
+    send_mail_sync.delay(subject, message, from_mail, recipient_list, fail_silently=False, html_message=None)
+
+    Also you can ignore the from_mail, unlike django send_mail, from_email is not a require args:
+
+    Example:
+    send_mail_sync.delay(subject, message, recipient_list, fail_silently=False, html_message=None)
+    """
+    if len(args) == 3:
+        args = list(args)
+        args[0] = settings.EMAIL_SUBJECT_PREFIX + args[0]
+        args.insert(2, settings.EMAIL_HOST_USER)
+        args = tuple(args)
+
+    send_mail(*args, **kwargs)
+
+
+# def send_mail_async(subject, message, from_mail, recipient_list, fail_silently=False, html_message=None):
+#     if settings.CONFIG.MAIL_SUBJECT_PREFIX:
+#         subject += settings.CONFIG.MAIL_SUBJECT_PREFIX
+#     send_mail(subject, message, from_mail, recipient_list, fail_silently=fail_silently, html_message=html_message)
diff --git a/apps/jumpserver/settings.py b/apps/jumpserver/settings.py
index f741fbd59..edb47798b 100644
--- a/apps/jumpserver/settings.py
+++ b/apps/jumpserver/settings.py
@@ -236,6 +236,16 @@ BOOTSTRAP_COLUMN_COUNT = 11
 # Init data or generate fake data source for development
 FIXTURE_DIRS = [os.path.join(BASE_DIR, 'fixtures'), ]
 
+
+# Email config
+EMAIL_HOST = CONFIG.EMAIL_HOST
+EMAIL_PORT = CONFIG.EMAIL_PORT
+EMAIL_HOST_USER = CONFIG.EMAIL_HOST_USER
+EMAIL_HOST_PASSWORD = CONFIG.EMAIL_HOST_PASSWORD
+EMAIL_USE_SSL = CONFIG.EMAIL_USE_SSL
+EMAIL_USE_TLS = CONFIG.EMAIL_USE_TLS
+EMAIL_SUBJECT_PREFIX = CONFIG.EMAIL_SUBJECT_PREFIX
+
 REST_FRAMEWORK = {
     # Use Django's standard `django.contrib.auth` permissions,
     # or allow read-only access for unauthenticated users.
diff --git a/apps/users/models.py b/apps/users/models.py
index 84c34c48c..aead530ba 100644
--- a/apps/users/models.py
+++ b/apps/users/models.py
@@ -3,6 +3,7 @@
 from __future__ import unicode_literals
 
 import datetime
+
 from django.conf import settings
 from django.contrib.auth.hashers import make_password
 from django.utils import timezone
@@ -13,6 +14,7 @@ from django.dispatch import receiver
 from django.db import IntegrityError
 from rest_framework.authtoken.models import Token
 
+from django.core import signing
 
 # class Role(models.Model):
 #     name = models.CharField('name', max_length=80, unique=True)
@@ -113,8 +115,6 @@ class User(AbstractUser):
     private_key = models.CharField(max_length=5000, blank=True, verbose_name='ssh私钥')  # ssh key max length 4096 bit
     public_key = models.CharField(max_length=1000, blank=True, verbose_name='公钥')
     comment = models.TextField(max_length=200, blank=True, verbose_name='描述')
-    confirmed = models.BooleanField(default=False)
-    date_confirmed = models.DateField(blank=True, null=True, verbose_name='确认时间')
     date_expired = models.DateTimeField(default=date_expired_default, blank=True, null=True, verbose_name='有效期')
     created_by = models.CharField(max_length=30, default='')
 
@@ -177,22 +177,43 @@ class User(AbstractUser):
             # super(User, self).save(*args, **kwargs)
 
     @property
-    def token(self):
-        return self.get_token()
+    def private_token(self):
+        return self.get_private_token()
 
-    def get_token(self):
+    def get_private_token(self):
         try:
             token = Token.objects.get(user=self)
-            return token.key
         except Token.DoesNotExist:
-            return ''
+            token = Token.objects.create(user=self)
 
-    def set_token(self):
+        return token.key
+
+    def refresh_private_token(self):
+        Token.objects.filter(user=self).delete()
+        return Token.objects.create(user=self)
+
+    @classmethod
+    def generate_reset_token(cls, email):
         try:
-            return Token.objects.create(user=self)
-        except IntegrityError:
-            Token.objects.filter(user=self).delete()
-            return Token.objects.create(user=self)
+            user = cls.objects.get(email=email)
+            return signing.dumps({'reset': user.id, 'email': user.email})
+        except cls.DoesNotExist:
+            return None
+
+    @classmethod
+    def reset_password(cls, token, new_password, max_age=3600):
+        try:
+            data = signing.loads(token, max_age=max_age)
+            user_id = data.get('reset', None)
+            user_email = data.get('email', '')
+            user = cls.objects.get(id=user_id, email=user_email)
+            user.set_password(new_password)
+            user.save()
+            return True
+
+        except signing.BadSignature, cls.DoesNotExist:
+            pass
+        return False
 
     class Meta:
         db_table = 'user'
diff --git a/apps/users/tests/test_models.py b/apps/users/tests/test_models.py
index 81d3f5b85..e011aa209 100644
--- a/apps/users/tests/test_models.py
+++ b/apps/users/tests/test_models.py
@@ -75,6 +75,14 @@ class UserModelTest(TransactionTestCase):
         self.assertTrue(user.check_password(password))
         self.assertFalse(user.check_password(password*2))
 
+    def test_user_reset_password(self):
+        user = User.objects.first()
+        token = User.generate_reset_token(user.email)
+        new_password = gen_username()
+        User.reset_password(token, new_password)
+        user_ = User.objects.get(id=user.id)
+        self.assertTrue(user_.check_password(new_password))
+
     def tearDown(self):
         User.objects.all().delete()
         UserGroup.objects.all().delete()
diff --git a/apps/users/tests/test_views.py b/apps/users/tests/test_views.py
index 74af9e3b0..6552d03b7 100644
--- a/apps/users/tests/test_views.py
+++ b/apps/users/tests/test_views.py
@@ -12,6 +12,7 @@ from .base import gen_username, gen_name, gen_email, get_role
 class UserListViewTests(TransactionTestCase):
     def setUp(self):
         init_all_models()
+        self.client.login(username='admin', password='admin')
 
     def test_a_new_user_in_list(self):
         username = gen_username()
@@ -32,10 +33,14 @@ class UserListViewTests(TransactionTestCase):
         response = self.client.get(reverse('users:user-list'))
         self.assertEqual(response.context['is_paginated'], True)
 
+    def tearDown(self):
+        self.client.logout()
+
 
 class UserAddTests(TestCase):
     def setUp(self):
         init_all_models()
+        self.client.login(username='admin', password='admin')
 
     def test_add_a_new_user(self):
         username = gen_username()
@@ -56,3 +61,6 @@ class UserAddTests(TestCase):
         response = self.client.get(reverse('users:user-list'))
         self.assertContains(response, username)
 
+    def tearDown(self):
+        self.client.logout()
+
diff --git a/apps/users/utils.py b/apps/users/utils.py
index d6cee1080..8c5b0e32e 100644
--- a/apps/users/utils.py
+++ b/apps/users/utils.py
@@ -8,6 +8,7 @@ from paramiko.rsakey import RSAKey
 from django.contrib.auth.mixins import UserPassesTestMixin
 from django.urls import reverse_lazy
 
+
 try:
     import cStringIO as StringIO
 except ImportError: