Compare commits

...

11 Commits

  1. 15
      backend/decorators/__init__.py
  2. 87
      backend/decorators/cbv_decorator.py
  3. 51
      backend/decorators/exception_log.py
  4. 13
      backend/middlewares/__init__.py
  5. 48
      backend/middlewares/auth.py
  6. 13
      backend/routers/__init__.py
  7. 90
      backend/routers/account/models.py
  8. 19
      backend/routers/account/request_body.py
  9. 45
      backend/routers/account/views.py
  10. 53
      backend/utils/__init__.py
  11. 41
      backend/utils/gmssl_pack.py
  12. 82
      backend/utils/mysql.py
  13. 288
      backend/utils/redis_lib.py

15
backend/decorators/__init__.py

@ -0,0 +1,15 @@
# encoding: utf-8
"""
@author: Qiancj
@contact: qiancj@risenenergy.com
@file: __init__.py
@create-time: 2023-09-25 15:42
@description: The new python script
"""
__all__ = [
"cbv",
"DecratorSet",
]
from .cbv_decorator import cbv
from .exception_log import DecratorSet

87
backend/decorators/cbv_decorator.py

@ -0,0 +1,87 @@
# encoding: utf-8
"""
@author: Qiancj
@contact: qiancj@risenenergy.com
@file: cbv
@create-time: 2023-09-25 15:42
@description: The new python script
"""
import inspect
from typing import get_type_hints
from fastapi import APIRouter, Depends
from pydantic.typing import is_classvar
from starlette.routing import Route, WebSocketRoute
CBV_CLASS_KEY = "__cbv_class__"
def _update_cbv_route_endpoint_signature(cls, route):
old_endpoint = route.endpoint
old_signature = inspect.signature(old_endpoint)
old_parameters = list(old_signature.parameters.values())
old_first_parameter = old_parameters[0]
new_first_parameter = old_first_parameter.replace(default=Depends(cls))
new_parameters = [new_first_parameter] + [
parameter.replace(kind=inspect.Parameter.KEYWORD_ONLY) for parameter in old_parameters[1:]
]
new_signature = old_signature.replace(parameters=new_parameters)
setattr(route.endpoint, "__signature__", new_signature)
def _init_cbv(cls):
if getattr(cls, CBV_CLASS_KEY, False): # pragma: no cover
return # Already initialized
old_init = cls.__init__
old_signature = inspect.signature(old_init)
old_parameters = list(old_signature.parameters.values())[1:] # drop `self` parameter
new_parameters = [
x for x in old_parameters if x.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
]
dependency_names = []
for name, hint in get_type_hints(cls).items():
if is_classvar(hint):
continue
parameter_kwargs = {"default": getattr(cls, name, Ellipsis)}
dependency_names.append(name)
new_parameters.append(
inspect.Parameter(name=name, kind=inspect.Parameter.KEYWORD_ONLY, annotation=hint, **parameter_kwargs)
)
new_signature = old_signature.replace(parameters=new_parameters)
def new_init(self, *args, **kwargs):
for dep_name in dependency_names:
dep_value = kwargs.pop(dep_name)
setattr(self, dep_name, dep_value)
old_init(self, *args, **kwargs)
setattr(cls, "__signature__", new_signature)
setattr(cls, "__init__", new_init)
setattr(cls, CBV_CLASS_KEY, True)
def _cbv(router, cls):
_init_cbv(cls)
cbv_router = APIRouter()
function_members = inspect.getmembers(cls, inspect.isfunction)
functions_set = set(func for _, func in function_members)
cbv_routes = [
route
for route in router.routes
if isinstance(route, (Route, WebSocketRoute)) and route.endpoint in functions_set
]
for route in cbv_routes:
router.routes.remove(route)
_update_cbv_route_endpoint_signature(cls, route)
cbv_router.routes.append(route)
router.include_router(cbv_router)
return cls
def cbv(router):
def decorator(cls):
return _cbv(router, cls)
return decorator

51
backend/decorators/exception_log.py

@ -0,0 +1,51 @@
# encoding: utf-8
"""
@author: Qiancj
@contact: qiancj@risenenergy.com
@file: exception_log
@create-time: 2023-09-26 09:55
@description: The new python script
"""
import traceback
from functools import wraps
class DecratorSet:
@staticmethod
def log_dec(func):
@wraps(func)
def inner(*args, **kwargs):
status = False
out = {"code": 400500, "msg": "程序内部未知错误"}
try:
out = func(*args, **kwargs)
status = True
except Exception as e:
status = False
err = repr(e)
msg_content = traceback.format_exc()
# msg_content_list = msg_content.split("\n")
# pattern = re.compile(r'^\s*File\s*\S+[\\/]models\.py(\S*\s*)')
# msg_list = [i for i in msg_content_list if pattern.match(i)]
# location = msg_list[-1].strip() if msg_list else "未知错误"
out = {"status": 500500, "data": None, "msg": msg_content}
finally:
if not status:
func_name = func.__name__
# redis.init_conn()
# redis.pub(
# "real_pub",
# str({
# "func": func_name,
# "err": err,
# "params": msg_param,
# "location": location,
# "details": msg_content
# })
# )
# redis.close()
return out
return inner

13
backend/middlewares/__init__.py

@ -0,0 +1,13 @@
# encoding: utf-8
"""
@author: Qiancj
@contact: qiancj@risenenergy.com
@file: __init__.py
@create-time: 2023-09-06 17:44
@description: The new python script
"""
__all__ = [
"AuthorizationMiddleware"
]
from .auth import AuthorizationMiddleware

48
backend/middlewares/auth.py

@ -0,0 +1,48 @@
# encoding: utf-8
"""
@author: Qiancj
@contact: qiancj@risenenergy.com
@file: auth
@create-time: 2023-09-27 15:13
@description: The new python script
"""
from starlette.datastructures import URL, Headers
from starlette.responses import JSONResponse
from utils import redis_db
URL_WHITE_LIST = {
"/account/login/",
"/account/get_user_salt/",
"/doc/",
"/redoc/",
"test"
}
class AuthorizationMiddleware:
def __init__(self, app):
self.app = app
def __call__(self, scope, receive, send):
url = URL(scope=scope)
path = url.path
if path in URL_WHITE_LIST:
return self.app(scope, receive, send)
headers = Headers(scope=scope)
token = headers.get("authorization") or headers.get("auth-token")
response = self.app
data = {"status": 400403, "data": None, "msg": "用户未登录禁止使用该功能"}
try:
assert token
data = {"status": 400403, "data": None, "msg": "该用户登录已过期或未登录禁止使用该功能"}
token_key = token.split("---")[0]
with redis_db as r:
assert r.exists(token_key)
data = {"status": 400403, "data": None, "msg": "该用户已在其他处登录,请重新登陆"}
cache = r.get(token_key)
assert cache == token_key
except AssertionError:
response = JSONResponse(data)
return response(scope, receive, send)

13
backend/routers/__init__.py

@ -0,0 +1,13 @@
# encoding: utf-8
"""
@author: Qiancj
@contact: qiancj@risenenergy.com
@file: __init__.py
@create-time: 2023-09-06 17:45
@description: The new python script
"""
__all__ = [
"account_router",
]
from .account.views import account_router

90
backend/routers/account/models.py

@ -0,0 +1,90 @@
# encoding: utf-8
"""
@author: Qiancj
@contact: qiancj@risenenergy.com
@file: models
@create-time: 2023-09-26 14:07
@description: The new python script
"""
import binascii
import os
from utils import mysql_db, gm_encrypt, redis_db, gm_decrypt
class LoginOperations:
_instance = None
def __new__(cls, *args, **kw):
if not cls._instance:
cls._instance = object.__new__(cls)
return cls._instance
def __init__(self, username, step, password=None):
self.__user = username
self.__step = step
self.__pwd = password
def get_user_salt(self):
sql = f"""
select salt
from tb_users
where username='{self.__user}' and
step='{self.__step}'
"""
row = mysql_db.get_one(sql)
if not row:
return {"status": 400404, "data": None, "msg": "未查询到该用户请查实后重新登录"}
salt = row[0]
salt_out_value = gm_encrypt.sm2_encrypt(salt)
return {
"status": 200, "data": salt_out_value, "msg": "Sucess!"
}
@property
def __generate_user_token(self):
return binascii.hexlify(os.urandom(32)).decode()
@staticmethod
def __save_token(key, token):
with redis_db as r:
r.set(key, token)
def login(self):
pwd = gm_decrypt.sm2_decrypt(self.__pwd)
sql = f"""
select is_admin
from tb_users
where username='{self.__user}' and
step='{self.__step}' and
concat(password)='{pwd}'
"""
row = mysql_db.get_one(sql)
if not row:
sql = f"""
select username
from tb_users
where username='{self.__user}' and
concat(password)='{pwd}'
"""
row = mysql_db.get_one(sql)
if not row:
return {"status": 400403, "data": None, "msg": "用户名或密码错误,请查实后重新登录"}
return {"status": 400200, "data": None, "msg": "该用户没有登录该岗位的权限,请重新选择正确岗位后登录"}
token = self.__generate_user_token
self.__save_token(f"{self.__user}_token", token)
return {
"status": 200,
"data": {
"user": self.__user,
"role": row[0],
"token": f"{self.__user}_token---{token}"
},
"msg": "Sucess!"
}
def user_logout(user):
with redis_db as r:
r.delete(f"{user}_token")
return {"status": 200, "data": None, "msg": "用户已退出登录"}

19
backend/routers/account/request_body.py

@ -0,0 +1,19 @@
# encoding: utf-8
"""
@author: Qiancj
@contact: qiancj@risenenergy.com
@file: request_body
@create-time: 2023-09-26 14:54
@description: The new python script
"""
from pydantic import BaseModel
class UserInfo(BaseModel):
user: str
step: str
pwd: str
class Username(BaseModel):
user: str

45
backend/routers/account/views.py

@ -0,0 +1,45 @@
# encoding: utf-8
"""
@author: Qiancj
@contact: qiancj@risenenergy.com
@file: account
@create-time: 2023-09-25 15:04
@description: The new python script
"""
from fastapi import APIRouter
from decorators import cbv
from decorators import DecratorSet
from .models import LoginOperations, user_logout
from .request_body import UserInfo, Username
account_router = APIRouter()
@cbv(account_router)
class LoginAPI:
@account_router.post("/login/", tags=["account"])
@DecratorSet.log_dec
def login(self, user_info: UserInfo):
user = user_info.user
step = user_info.step
pwd = user_info.pwd
out = LoginOperations(user, step, pwd).login()
return out
@account_router.get("/get_user_salt/", tags=["account"])
@DecratorSet.log_dec
def get_salt(self, user: str, step: str):
out = LoginOperations(user, step).get_user_salt()
return out
@cbv(account_router)
class LogoutAPI:
@account_router.post("/logout/", tags=["account"])
def logout(self, user_info: Username):
user = user_info.user
out = user_logout(user)
return out

53
backend/utils/__init__.py

@ -0,0 +1,53 @@
# encoding: utf-8
"""
@author: Qiancj
@contact: qiancj@risenenergy.com
@file: __init__.py
@create-time: 2023-09-22 15:38
@description: The new python script
"""
__all__ = [
"keyb",
"GmEncryptDecrypt",
"MysqlPool",
"mysql_db",
"gm_encrypt",
"gm_decrypt",
"Redis",
"redis_db",
]
from .gmssl_pack import GmEncryptDecrypt
from .mysql import MysqlPool
from .redis_lib import Redis
keyb = {
"pkf": "",
"pkb": "",
"sk": ""
}
mysql_db = MysqlPool(
host="",
port=0,
user="",
pwd="",
db=""
)
gm_encrypt = GmEncryptDecrypt(
sm2_public_key=keyb.get("pkf")
)
gm_decrypt = GmEncryptDecrypt(
sm2_public_key=keyb.get("pkb"),
sm2_private_key=keyb.get("sk")
)
redis_db = Redis(
host="",
port=,
db="",
password=""
)

41
backend/utils/gmssl_pack.py

@ -0,0 +1,41 @@
# encoding: utf-8
"""
@author: Qiancj
@contact: qiancj@risenenergy.com
@file: gmssl_pack
@create-time: 2023-09-21 09:22
@description: The new python script
"""
from gmssl.sm2 import CryptSM2
from gmssl.sm3 import sm3_hash
class GmEncryptDecrypt:
"""
国密sm4加解密
"""
def __init__(self, sm2_private_key=None, sm2_public_key=None, mode=1):
self.__sm2_endecrpt = CryptSM2(private_key=sm2_private_key, public_key=sm2_public_key, mode=mode)
def sm2_encrypt(self, text):
data = text.encode() if isinstance(text, str) else text
return self.__sm2_endecrpt.encrypt(data).hex()
def sm2_decrypt(self, msg):
data = bytes.fromhex(msg)
return self.__sm2_endecrpt.decrypt(data).decode()
@staticmethod
def sm3_encrypt(text):
return sm3_hash(bytearray(text, encoding="utf-8"))
if __name__ == '__main__':
# a = GmEncryptDecrypt.sm3_encrypt("qjx171024..ai7#n_;scy*$")
a = GmEncryptDecrypt(sm2_private_key="8b770da15fdc6d8d8ba9f0399c4fdc0a289d3a40f2871b286466da4b0a42a854",sm2_public_key=(
"f4424e38364741a011a4ec1c42727b5bc7e1f080719740d0adf1bfe14ceaedb1"
"161df14e75acda44af05ae71f2a9abd19a26cd64a6ac5a2fe665d8346a65b6ec"
)).sm2_decrypt("f1b3db07bbb0f2997d0f5798f5ef688afd70ee0267dd29c817950bf80e0bac7517aed2cad06e0346eb9d6348bee8fe4d1b73c71f9d0fabb85b48a85da301242f73d0a1fe29b7cbb6bbf97ddde4aeb1b8665d2c777e5052e3038495722d71e0f9caf00555ab585c57ff855c1a64601c3dd55f61067f006bc4338b15301cf7ce7b1ab160e65a1d19dd618b0f4bb17923c778b2e529713c90a105cee5a34c692243")
print(a)

82
backend/utils/mysql.py

@ -0,0 +1,82 @@
# encoding: utf-8
"""
@author: Qiancj
@contact: qiancj@risenenergy.com
@file: mysql
@create-time: 2023-09-22 15:38
@description: The new python script
"""
import pandas as pd
import pymysql
from dbutils.pooled_db import PooledDB
class MysqlPool:
_instance = None
def __init__(self, host, port, user, pwd, db):
self.POOL = PooledDB(
creator=pymysql,
maxconnections=30, # 连接池的最大连接数
maxcached=10,
blocking=True,
setsession=[],
host=host,
port=port,
user=user,
password=pwd,
database=db,
charset='utf8',
)
def __new__(cls, *args, **kw):
if not cls._instance:
cls._instance = object.__new__(cls)
return cls._instance
def connect(self):
conn = self.POOL.connection()
cursor = conn.cursor()
return conn, cursor
@staticmethod
def connect_close(conn=None, cursor=None):
if cursor:
cursor.close()
if conn:
conn.close()
def insert_many(self, sql, data):
conn, cursor = self.connect()
cursor.executemany(sql, data)
conn.commit()
self.connect_close(conn, cursor)
def execute_sql(self, sql):
conn, cursor = self.connect()
cursor.execute(sql)
conn.commit()
self.connect_close(conn, cursor)
def get_one(self, sql):
conn, cursor = self.connect()
cursor.execute(sql)
row = cursor.fetchone()
self.connect_close(conn, cursor)
return row
def get_all(self, sql):
conn, cursor = self.connect()
cursor.execute(sql)
rows = cursor.fetchall()
self.connect_close(conn, cursor)
return rows
def get_df(self, sql, index_name=None):
conn, cursor = self.connect()
cursor.execute(sql)
rows = cursor.fetchall()
cols = [i[0] for i in cursor.description]
df = pd.DataFrame(rows, columns=cols)
self.connect_close(conn, cursor)
return df.set_index(index_name) if index_name else df

288
backend/utils/redis_lib.py

@ -0,0 +1,288 @@
# encoding: utf-8
"""
@author: Qiancj
@contact: qiancj@risenenergy.com
@file: redis
@create-time: 2023-09-26 16:18
@description: The new python script
"""
from contextlib import AbstractContextManager
from redis import ConnectionPool, StrictRedis, ResponseError
class Redis(AbstractContextManager):
"""
Redis 统一连接以及常用命令工具包
"""
instance = None
def __init__(self, host=None, port=None, db=None, password=None, timeout=None):
"""
初始化 Redis 连接信息
:param host: Redis IP
:param port: Redis 端口
:param db: Redis 序号
:param password: Redis 密码
:param timeout: Redis 存储时 key 的默认过期时间
"""
self.__host = host
self.__port = port
self.__db = db
self.__password = password
self.__timeout = timeout
self.__conn = None
def init_conn(self, db=None, timeout=None):
"""
连接 redis
:param timeout: 连接 redis 超时时间
:param db: 在连接时可以指定另一个表但是慎用
"""
# 确保单次操作 redis 连接数只有一个。
# 需要经过前端联调测试是否合理
if not db:
db = self.__db
url = f"redis://:{self.__password}@{self.__host}:{self.__port}/{db}"
pool = ConnectionPool.from_url(url)
self.__conn = StrictRedis(connection_pool=pool, socket_connect_timeout=timeout)
def __enter__(self):
self.init_conn()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def set(self, key, value, timeout=None):
"""
redis 中插入一对 key-value
:param key: 插入的 key
:param value: 插入的 value
:param timeout: 插入的 key-value 的过期时间
"""
ex = timeout if timeout else self.__timeout
self.__conn.set(name=key, value=value, ex=ex)
def get(self, key):
"""
redis 中获取某个 key 的value
:param key: 要获取的 key
:return: redis 对应 key value
"""
result = self.__conn.get(key)
out = str(result, encoding="utf-8") if result else None
return out
def delete(self, *key):
"""
删除某对或者多对 key-value
:param key: 要删除的 key
"""
self.__conn.delete(*key)
def ttl(self, key):
"""
获取某个 key 当前剩余的过期时间
:param key: 要获取的 key
:return: 过期时间
"""
result = self.__conn.ttl(key)
out = int(result) if result else 0
return out
def lpush(self, key: str, *value: any) -> None:
"""
左边入队操作
@param key: 插入的key
@param value: 入队的值
"""
self.__conn.lpush(key, *value)
def rpush(self, key: str, *value: any) -> None:
"""
右边入队操作
@param key: 插入的key
@param value: 入队的值
"""
if value:
self.__conn.rpush(key, *value)
def lpop(self, key: str, count: int = 1) -> object:
"""
左边出队操作
@param key: 要出队的key
@param count: 出队数量
@return value: 出队元素
"""
result = self.__conn.lpop(key, count)
value = [v.decode() for v in result] if result else []
return value
def rpop(self, key: str, count: int = 1) -> object:
"""
右边出队操作
@param key: 要出队的key
@param count: 出队数量
@return value: 出队元素
"""
value = [v.decode() for v in self.__conn.rpop(key, count)]
return value
def rpop_lpush(self, key: str, backup_key: str) -> object:
"""
右边出队左边入队操作
@param key: 要出队的key
@param backup_key: 备份的key
@return value: 出队元素
"""
value = self.__conn.rpoplpush(key, backup_key)
return value
def llen(self, key: str) -> int:
"""
获取指定key队列的数量
@param key: 要查询的key
@return count: 指定key队列的数量
"""
return self.__conn.llen(key)
def lrange(self, key: str, start: int = 0, end: int = -1) -> list:
"""
获取指定key在指定下标范围内的队列值
@param key: 要查询的key
@param start: 队列起始下标
@param end: 队列终止下标
@return value: 指定下标范围内的队列值
"""
value = [v.decode() for v in self.__conn.lrange(key, start, end)]
return value
def ltrim(self, key: str, start: int = 1, end: int = 0) -> bool:
"""
删除指定key指定范围内的数据
@param key: 要删除的key
@param start: 队列起始下标
@param end: 队列终止下标
@return bool: 删除成功或失败
"""
return self.__conn.ltrim(key, start, end)
def lrem(self, key: str, value: str, count: int = 0) -> int:
"""
删除指定key指定元素
@param key: 要删除的key
@param value: 队列起始下标
@param count: 删除数量
如果大于0则从左往右删除队列中指定数量的指定元素
如果等于0则删除队列中全量指定元素
如果小于0则从右往左删除队列中指定数量的指定元素
@return num: 删除成功的数量
"""
return self.__conn.lrem(key, count, value)
def zadd(self, key, values):
"""
redis 中插入有序集合
:param key: 插入的 key
:param values: 这里的 value 是一个 dict
有序集合是根据元素的 score 来进行排序的
dict key 是有序集合中的元素
dict value 则是元素的 score
:return:
"""
self.__conn.zadd(key, values)
def get_zmenber(self, key, mins, maxs):
"""
获取有序集合中排名在某个范围的元素
当前该方法是为了 IP 黑名单设计的后续可以根据其他需求重构
:param key: 要获取的有序集合的 key
:param mins: score 下限
:param maxs: score 上限
:return: 该有序集合中的某个 score 范围内的元素集合
"""
inner_mins = mins - 60 * 60 * 24 * 7
self.__conn.zremrangebyscore(key, inner_mins, mins - 1)
result = self.__conn.zrangebyscore(key, mins, maxs)
out = {i.decode(encoding="utf-8") for i in result}
return out
def exists(self, key):
return self.__conn.exists(key)
def expire(self, key, timeout):
self.__conn.expire(key, timeout)
def json_set(self, key, path, *json_data, timeout=None, flag=None):
"""
该方法必须要搭配 RedisJSON 模块使用
该方法是向 redis 中添加或更新 json 数据
当前该方法是为了 log 日志设计
:param key: json key
:param path: json 的节点是个字符串类型
:param json_data: dict 类型是要添加进 redis 的值
:param timeout: key 的过期时间
:param flag:
:return:
"""
try:
if flag:
self.__conn.json().set(key, ".", {path[1:]: json_data[0]})
return
self.__conn.json().arrappend(key, path, *json_data)
except ResponseError:
try:
self.__conn.json().set(key, path, [json_data[0]])
except ResponseError:
self.__conn.json().set(key, ".", {path[1:]: [json_data[0]]})
if len(json_data) >= 2:
self.json_set(key, path, *json_data[1:])
finally:
if timeout:
self.__conn.expire(key, timeout)
def json_get(self, key, *path):
try:
objkeys = self.__conn.json().objkeys(key)
except Exception as _:
objkeys = []
if not objkeys:
objkeys = []
path = set(path) & set(objkeys)
result = self.__conn.json().get(key, *path)
return result
def close(self):
"""
关闭 redis 连接
:return:
"""
if self.__conn:
self.__conn.close()
def __del__(self):
self.close()
def pub(self, channel, message, **kwargs):
self.__conn.publish(channel, message, **kwargs)
def pubsub(self, **kwargs):
return self.__conn.pubsub(**kwargs)
Loading…
Cancel
Save