Compare commits
11 Commits
73c4a54fca
...
18d5b0e9c1
Author | SHA1 | Date |
---|---|---|
Jingxun | 18d5b0e9c1 | 8 months ago |
Jingxun | aa9c970575 | 8 months ago |
Jingxun | 9932c5fb07 | 8 months ago |
Jingxun | 07ad66ff26 | 8 months ago |
Jingxun | 8e7917d2e3 | 8 months ago |
Jingxun | e67d8e7ded | 8 months ago |
Jingxun | a697e25873 | 8 months ago |
Jingxun | 4036b52867 | 8 months ago |
Jingxun | 2bc5066260 | 8 months ago |
Jingxun | 8d70cc3792 | 8 months ago |
Jingxun | e091ae73ba | 8 months ago |
13 changed files with 845 additions and 0 deletions
@ -0,0 +1,15 @@ |
|||
# encoding: utf-8 |
|||
""" |
|||
@author: Qiancj |
|||
@contact: qiancj@risenenergy.com |
|||
@file: __init__.py |
|||
@create-time: 2023-09-25 15:42 |
|||
@description: The new python script |
|||
""" |
|||
__all__ = [ |
|||
"cbv", |
|||
"DecratorSet", |
|||
] |
|||
|
|||
from .cbv_decorator import cbv |
|||
from .exception_log import DecratorSet |
@ -0,0 +1,87 @@ |
|||
# encoding: utf-8 |
|||
""" |
|||
@author: Qiancj |
|||
@contact: qiancj@risenenergy.com |
|||
@file: cbv |
|||
@create-time: 2023-09-25 15:42 |
|||
@description: The new python script |
|||
""" |
|||
|
|||
import inspect |
|||
from typing import get_type_hints |
|||
|
|||
from fastapi import APIRouter, Depends |
|||
from pydantic.typing import is_classvar |
|||
from starlette.routing import Route, WebSocketRoute |
|||
|
|||
CBV_CLASS_KEY = "__cbv_class__" |
|||
|
|||
|
|||
def _update_cbv_route_endpoint_signature(cls, route): |
|||
old_endpoint = route.endpoint |
|||
old_signature = inspect.signature(old_endpoint) |
|||
old_parameters = list(old_signature.parameters.values()) |
|||
old_first_parameter = old_parameters[0] |
|||
new_first_parameter = old_first_parameter.replace(default=Depends(cls)) |
|||
new_parameters = [new_first_parameter] + [ |
|||
parameter.replace(kind=inspect.Parameter.KEYWORD_ONLY) for parameter in old_parameters[1:] |
|||
] |
|||
new_signature = old_signature.replace(parameters=new_parameters) |
|||
setattr(route.endpoint, "__signature__", new_signature) |
|||
|
|||
|
|||
def _init_cbv(cls): |
|||
if getattr(cls, CBV_CLASS_KEY, False): # pragma: no cover |
|||
return # Already initialized |
|||
old_init = cls.__init__ |
|||
old_signature = inspect.signature(old_init) |
|||
old_parameters = list(old_signature.parameters.values())[1:] # drop `self` parameter |
|||
new_parameters = [ |
|||
x for x in old_parameters if x.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) |
|||
] |
|||
dependency_names = [] |
|||
for name, hint in get_type_hints(cls).items(): |
|||
if is_classvar(hint): |
|||
continue |
|||
parameter_kwargs = {"default": getattr(cls, name, Ellipsis)} |
|||
dependency_names.append(name) |
|||
new_parameters.append( |
|||
inspect.Parameter(name=name, kind=inspect.Parameter.KEYWORD_ONLY, annotation=hint, **parameter_kwargs) |
|||
) |
|||
new_signature = old_signature.replace(parameters=new_parameters) |
|||
|
|||
def new_init(self, *args, **kwargs): |
|||
for dep_name in dependency_names: |
|||
dep_value = kwargs.pop(dep_name) |
|||
setattr(self, dep_name, dep_value) |
|||
old_init(self, *args, **kwargs) |
|||
|
|||
setattr(cls, "__signature__", new_signature) |
|||
setattr(cls, "__init__", new_init) |
|||
setattr(cls, CBV_CLASS_KEY, True) |
|||
|
|||
|
|||
def _cbv(router, cls): |
|||
_init_cbv(cls) |
|||
cbv_router = APIRouter() |
|||
function_members = inspect.getmembers(cls, inspect.isfunction) |
|||
functions_set = set(func for _, func in function_members) |
|||
cbv_routes = [ |
|||
route |
|||
for route in router.routes |
|||
if isinstance(route, (Route, WebSocketRoute)) and route.endpoint in functions_set |
|||
] |
|||
for route in cbv_routes: |
|||
router.routes.remove(route) |
|||
_update_cbv_route_endpoint_signature(cls, route) |
|||
cbv_router.routes.append(route) |
|||
router.include_router(cbv_router) |
|||
return cls |
|||
|
|||
|
|||
def cbv(router): |
|||
|
|||
def decorator(cls): |
|||
return _cbv(router, cls) |
|||
|
|||
return decorator |
@ -0,0 +1,51 @@ |
|||
# encoding: utf-8 |
|||
""" |
|||
@author: Qiancj |
|||
@contact: qiancj@risenenergy.com |
|||
@file: exception_log |
|||
@create-time: 2023-09-26 09:55 |
|||
@description: The new python script |
|||
""" |
|||
|
|||
import traceback |
|||
from functools import wraps |
|||
|
|||
|
|||
class DecratorSet: |
|||
|
|||
@staticmethod |
|||
def log_dec(func): |
|||
@wraps(func) |
|||
def inner(*args, **kwargs): |
|||
status = False |
|||
out = {"code": 400500, "msg": "程序内部未知错误"} |
|||
try: |
|||
out = func(*args, **kwargs) |
|||
status = True |
|||
except Exception as e: |
|||
status = False |
|||
err = repr(e) |
|||
msg_content = traceback.format_exc() |
|||
# msg_content_list = msg_content.split("\n") |
|||
# pattern = re.compile(r'^\s*File\s*\S+[\\/]models\.py(\S*\s*)') |
|||
# msg_list = [i for i in msg_content_list if pattern.match(i)] |
|||
# location = msg_list[-1].strip() if msg_list else "未知错误" |
|||
out = {"status": 500500, "data": None, "msg": msg_content} |
|||
finally: |
|||
if not status: |
|||
func_name = func.__name__ |
|||
# redis.init_conn() |
|||
# redis.pub( |
|||
# "real_pub", |
|||
# str({ |
|||
# "func": func_name, |
|||
# "err": err, |
|||
# "params": msg_param, |
|||
# "location": location, |
|||
# "details": msg_content |
|||
# }) |
|||
# ) |
|||
# redis.close() |
|||
return out |
|||
|
|||
return inner |
@ -0,0 +1,13 @@ |
|||
# encoding: utf-8 |
|||
""" |
|||
@author: Qiancj |
|||
@contact: qiancj@risenenergy.com |
|||
@file: __init__.py |
|||
@create-time: 2023-09-06 17:44 |
|||
@description: The new python script |
|||
""" |
|||
__all__ = [ |
|||
"AuthorizationMiddleware" |
|||
] |
|||
|
|||
from .auth import AuthorizationMiddleware |
@ -0,0 +1,48 @@ |
|||
# encoding: utf-8 |
|||
""" |
|||
@author: Qiancj |
|||
@contact: qiancj@risenenergy.com |
|||
@file: auth |
|||
@create-time: 2023-09-27 15:13 |
|||
@description: The new python script |
|||
""" |
|||
from starlette.datastructures import URL, Headers |
|||
from starlette.responses import JSONResponse |
|||
|
|||
from utils import redis_db |
|||
|
|||
URL_WHITE_LIST = { |
|||
"/account/login/", |
|||
"/account/get_user_salt/", |
|||
"/doc/", |
|||
"/redoc/", |
|||
"test" |
|||
} |
|||
|
|||
|
|||
class AuthorizationMiddleware: |
|||
def __init__(self, app): |
|||
self.app = app |
|||
|
|||
def __call__(self, scope, receive, send): |
|||
url = URL(scope=scope) |
|||
path = url.path |
|||
if path in URL_WHITE_LIST: |
|||
return self.app(scope, receive, send) |
|||
|
|||
headers = Headers(scope=scope) |
|||
token = headers.get("authorization") or headers.get("auth-token") |
|||
response = self.app |
|||
data = {"status": 400403, "data": None, "msg": "用户未登录禁止使用该功能"} |
|||
try: |
|||
assert token |
|||
data = {"status": 400403, "data": None, "msg": "该用户登录已过期或未登录禁止使用该功能"} |
|||
token_key = token.split("---")[0] |
|||
with redis_db as r: |
|||
assert r.exists(token_key) |
|||
data = {"status": 400403, "data": None, "msg": "该用户已在其他处登录,请重新登陆"} |
|||
cache = r.get(token_key) |
|||
assert cache == token_key |
|||
except AssertionError: |
|||
response = JSONResponse(data) |
|||
return response(scope, receive, send) |
@ -0,0 +1,13 @@ |
|||
# encoding: utf-8 |
|||
""" |
|||
@author: Qiancj |
|||
@contact: qiancj@risenenergy.com |
|||
@file: __init__.py |
|||
@create-time: 2023-09-06 17:45 |
|||
@description: The new python script |
|||
""" |
|||
__all__ = [ |
|||
"account_router", |
|||
] |
|||
|
|||
from .account.views import account_router |
@ -0,0 +1,90 @@ |
|||
# encoding: utf-8 |
|||
""" |
|||
@author: Qiancj |
|||
@contact: qiancj@risenenergy.com |
|||
@file: models |
|||
@create-time: 2023-09-26 14:07 |
|||
@description: The new python script |
|||
""" |
|||
import binascii |
|||
import os |
|||
|
|||
from utils import mysql_db, gm_encrypt, redis_db, gm_decrypt |
|||
|
|||
|
|||
class LoginOperations: |
|||
_instance = None |
|||
|
|||
def __new__(cls, *args, **kw): |
|||
if not cls._instance: |
|||
cls._instance = object.__new__(cls) |
|||
return cls._instance |
|||
|
|||
def __init__(self, username, step, password=None): |
|||
self.__user = username |
|||
self.__step = step |
|||
self.__pwd = password |
|||
|
|||
def get_user_salt(self): |
|||
sql = f""" |
|||
select salt |
|||
from tb_users |
|||
where username='{self.__user}' and |
|||
step='{self.__step}' |
|||
""" |
|||
row = mysql_db.get_one(sql) |
|||
if not row: |
|||
return {"status": 400404, "data": None, "msg": "未查询到该用户请查实后重新登录"} |
|||
salt = row[0] |
|||
salt_out_value = gm_encrypt.sm2_encrypt(salt) |
|||
return { |
|||
"status": 200, "data": salt_out_value, "msg": "Sucess!" |
|||
} |
|||
|
|||
@property |
|||
def __generate_user_token(self): |
|||
return binascii.hexlify(os.urandom(32)).decode() |
|||
|
|||
@staticmethod |
|||
def __save_token(key, token): |
|||
with redis_db as r: |
|||
r.set(key, token) |
|||
|
|||
def login(self): |
|||
pwd = gm_decrypt.sm2_decrypt(self.__pwd) |
|||
sql = f""" |
|||
select is_admin |
|||
from tb_users |
|||
where username='{self.__user}' and |
|||
step='{self.__step}' and |
|||
concat(password)='{pwd}' |
|||
""" |
|||
row = mysql_db.get_one(sql) |
|||
if not row: |
|||
sql = f""" |
|||
select username |
|||
from tb_users |
|||
where username='{self.__user}' and |
|||
concat(password)='{pwd}' |
|||
""" |
|||
row = mysql_db.get_one(sql) |
|||
if not row: |
|||
return {"status": 400403, "data": None, "msg": "用户名或密码错误,请查实后重新登录"} |
|||
return {"status": 400200, "data": None, "msg": "该用户没有登录该岗位的权限,请重新选择正确岗位后登录"} |
|||
token = self.__generate_user_token |
|||
self.__save_token(f"{self.__user}_token", token) |
|||
return { |
|||
"status": 200, |
|||
"data": { |
|||
"user": self.__user, |
|||
"role": row[0], |
|||
"token": f"{self.__user}_token---{token}" |
|||
}, |
|||
"msg": "Sucess!" |
|||
} |
|||
|
|||
|
|||
def user_logout(user): |
|||
with redis_db as r: |
|||
r.delete(f"{user}_token") |
|||
return {"status": 200, "data": None, "msg": "用户已退出登录"} |
@ -0,0 +1,19 @@ |
|||
# encoding: utf-8 |
|||
""" |
|||
@author: Qiancj |
|||
@contact: qiancj@risenenergy.com |
|||
@file: request_body |
|||
@create-time: 2023-09-26 14:54 |
|||
@description: The new python script |
|||
""" |
|||
from pydantic import BaseModel |
|||
|
|||
|
|||
class UserInfo(BaseModel): |
|||
user: str |
|||
step: str |
|||
pwd: str |
|||
|
|||
|
|||
class Username(BaseModel): |
|||
user: str |
@ -0,0 +1,45 @@ |
|||
# encoding: utf-8 |
|||
""" |
|||
@author: Qiancj |
|||
@contact: qiancj@risenenergy.com |
|||
@file: account |
|||
@create-time: 2023-09-25 15:04 |
|||
@description: The new python script |
|||
""" |
|||
|
|||
from fastapi import APIRouter |
|||
|
|||
from decorators import cbv |
|||
from decorators import DecratorSet |
|||
from .models import LoginOperations, user_logout |
|||
from .request_body import UserInfo, Username |
|||
|
|||
account_router = APIRouter() |
|||
|
|||
|
|||
@cbv(account_router) |
|||
class LoginAPI: |
|||
|
|||
@account_router.post("/login/", tags=["account"]) |
|||
@DecratorSet.log_dec |
|||
def login(self, user_info: UserInfo): |
|||
user = user_info.user |
|||
step = user_info.step |
|||
pwd = user_info.pwd |
|||
out = LoginOperations(user, step, pwd).login() |
|||
return out |
|||
|
|||
@account_router.get("/get_user_salt/", tags=["account"]) |
|||
@DecratorSet.log_dec |
|||
def get_salt(self, user: str, step: str): |
|||
out = LoginOperations(user, step).get_user_salt() |
|||
return out |
|||
|
|||
|
|||
@cbv(account_router) |
|||
class LogoutAPI: |
|||
@account_router.post("/logout/", tags=["account"]) |
|||
def logout(self, user_info: Username): |
|||
user = user_info.user |
|||
out = user_logout(user) |
|||
return out |
@ -0,0 +1,53 @@ |
|||
# encoding: utf-8 |
|||
""" |
|||
@author: Qiancj |
|||
@contact: qiancj@risenenergy.com |
|||
@file: __init__.py |
|||
@create-time: 2023-09-22 15:38 |
|||
@description: The new python script |
|||
""" |
|||
|
|||
__all__ = [ |
|||
"keyb", |
|||
"GmEncryptDecrypt", |
|||
"MysqlPool", |
|||
"mysql_db", |
|||
"gm_encrypt", |
|||
"gm_decrypt", |
|||
"Redis", |
|||
"redis_db", |
|||
] |
|||
|
|||
from .gmssl_pack import GmEncryptDecrypt |
|||
from .mysql import MysqlPool |
|||
from .redis_lib import Redis |
|||
|
|||
keyb = { |
|||
"pkf": "", |
|||
"pkb": "", |
|||
"sk": "" |
|||
} |
|||
|
|||
mysql_db = MysqlPool( |
|||
host="", |
|||
port=0, |
|||
user="", |
|||
pwd="", |
|||
db="" |
|||
) |
|||
|
|||
gm_encrypt = GmEncryptDecrypt( |
|||
sm2_public_key=keyb.get("pkf") |
|||
) |
|||
|
|||
gm_decrypt = GmEncryptDecrypt( |
|||
sm2_public_key=keyb.get("pkb"), |
|||
sm2_private_key=keyb.get("sk") |
|||
) |
|||
|
|||
redis_db = Redis( |
|||
host="", |
|||
port=, |
|||
db="", |
|||
password="" |
|||
) |
@ -0,0 +1,41 @@ |
|||
# encoding: utf-8 |
|||
""" |
|||
@author: Qiancj |
|||
@contact: qiancj@risenenergy.com |
|||
@file: gmssl_pack |
|||
@create-time: 2023-09-21 09:22 |
|||
@description: The new python script |
|||
""" |
|||
|
|||
from gmssl.sm2 import CryptSM2 |
|||
from gmssl.sm3 import sm3_hash |
|||
|
|||
|
|||
class GmEncryptDecrypt: |
|||
""" |
|||
国密sm4加解密 |
|||
""" |
|||
|
|||
def __init__(self, sm2_private_key=None, sm2_public_key=None, mode=1): |
|||
self.__sm2_endecrpt = CryptSM2(private_key=sm2_private_key, public_key=sm2_public_key, mode=mode) |
|||
|
|||
def sm2_encrypt(self, text): |
|||
data = text.encode() if isinstance(text, str) else text |
|||
return self.__sm2_endecrpt.encrypt(data).hex() |
|||
|
|||
def sm2_decrypt(self, msg): |
|||
data = bytes.fromhex(msg) |
|||
return self.__sm2_endecrpt.decrypt(data).decode() |
|||
|
|||
@staticmethod |
|||
def sm3_encrypt(text): |
|||
return sm3_hash(bytearray(text, encoding="utf-8")) |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
# a = GmEncryptDecrypt.sm3_encrypt("qjx171024..ai7#n_;scy*$") |
|||
a = GmEncryptDecrypt(sm2_private_key="8b770da15fdc6d8d8ba9f0399c4fdc0a289d3a40f2871b286466da4b0a42a854",sm2_public_key=( |
|||
"f4424e38364741a011a4ec1c42727b5bc7e1f080719740d0adf1bfe14ceaedb1" |
|||
"161df14e75acda44af05ae71f2a9abd19a26cd64a6ac5a2fe665d8346a65b6ec" |
|||
)).sm2_decrypt("f1b3db07bbb0f2997d0f5798f5ef688afd70ee0267dd29c817950bf80e0bac7517aed2cad06e0346eb9d6348bee8fe4d1b73c71f9d0fabb85b48a85da301242f73d0a1fe29b7cbb6bbf97ddde4aeb1b8665d2c777e5052e3038495722d71e0f9caf00555ab585c57ff855c1a64601c3dd55f61067f006bc4338b15301cf7ce7b1ab160e65a1d19dd618b0f4bb17923c778b2e529713c90a105cee5a34c692243") |
|||
print(a) |
@ -0,0 +1,82 @@ |
|||
# encoding: utf-8 |
|||
""" |
|||
@author: Qiancj |
|||
@contact: qiancj@risenenergy.com |
|||
@file: mysql |
|||
@create-time: 2023-09-22 15:38 |
|||
@description: The new python script |
|||
""" |
|||
import pandas as pd |
|||
import pymysql |
|||
from dbutils.pooled_db import PooledDB |
|||
|
|||
|
|||
class MysqlPool: |
|||
_instance = None |
|||
|
|||
def __init__(self, host, port, user, pwd, db): |
|||
self.POOL = PooledDB( |
|||
creator=pymysql, |
|||
maxconnections=30, # 连接池的最大连接数 |
|||
maxcached=10, |
|||
blocking=True, |
|||
setsession=[], |
|||
host=host, |
|||
port=port, |
|||
user=user, |
|||
password=pwd, |
|||
database=db, |
|||
charset='utf8', |
|||
) |
|||
|
|||
def __new__(cls, *args, **kw): |
|||
if not cls._instance: |
|||
cls._instance = object.__new__(cls) |
|||
return cls._instance |
|||
|
|||
def connect(self): |
|||
conn = self.POOL.connection() |
|||
cursor = conn.cursor() |
|||
return conn, cursor |
|||
|
|||
@staticmethod |
|||
def connect_close(conn=None, cursor=None): |
|||
if cursor: |
|||
cursor.close() |
|||
if conn: |
|||
conn.close() |
|||
|
|||
def insert_many(self, sql, data): |
|||
conn, cursor = self.connect() |
|||
cursor.executemany(sql, data) |
|||
conn.commit() |
|||
self.connect_close(conn, cursor) |
|||
|
|||
def execute_sql(self, sql): |
|||
conn, cursor = self.connect() |
|||
cursor.execute(sql) |
|||
conn.commit() |
|||
self.connect_close(conn, cursor) |
|||
|
|||
def get_one(self, sql): |
|||
conn, cursor = self.connect() |
|||
cursor.execute(sql) |
|||
row = cursor.fetchone() |
|||
self.connect_close(conn, cursor) |
|||
return row |
|||
|
|||
def get_all(self, sql): |
|||
conn, cursor = self.connect() |
|||
cursor.execute(sql) |
|||
rows = cursor.fetchall() |
|||
self.connect_close(conn, cursor) |
|||
return rows |
|||
|
|||
def get_df(self, sql, index_name=None): |
|||
conn, cursor = self.connect() |
|||
cursor.execute(sql) |
|||
rows = cursor.fetchall() |
|||
cols = [i[0] for i in cursor.description] |
|||
df = pd.DataFrame(rows, columns=cols) |
|||
self.connect_close(conn, cursor) |
|||
return df.set_index(index_name) if index_name else df |
@ -0,0 +1,288 @@ |
|||
# encoding: utf-8 |
|||
""" |
|||
@author: Qiancj |
|||
@contact: qiancj@risenenergy.com |
|||
@file: redis |
|||
@create-time: 2023-09-26 16:18 |
|||
@description: The new python script |
|||
""" |
|||
from contextlib import AbstractContextManager |
|||
|
|||
from redis import ConnectionPool, StrictRedis, ResponseError |
|||
|
|||
|
|||
class Redis(AbstractContextManager): |
|||
""" |
|||
Redis 统一连接以及常用命令工具包 |
|||
""" |
|||
instance = None |
|||
|
|||
def __init__(self, host=None, port=None, db=None, password=None, timeout=None): |
|||
""" |
|||
初始化 Redis 连接信息 |
|||
:param host: Redis IP |
|||
:param port: Redis 端口 |
|||
:param db: Redis 序号 |
|||
:param password: Redis 密码 |
|||
:param timeout: Redis 存储时 key 的默认过期时间 |
|||
""" |
|||
self.__host = host |
|||
self.__port = port |
|||
self.__db = db |
|||
self.__password = password |
|||
self.__timeout = timeout |
|||
self.__conn = None |
|||
|
|||
def init_conn(self, db=None, timeout=None): |
|||
""" |
|||
连接 redis |
|||
:param timeout: 连接 redis 超时时间 |
|||
:param db: 在连接时可以指定另一个表,但是慎用 |
|||
""" |
|||
# 确保单次操作 redis 连接数只有一个。 |
|||
# 需要经过前端联调测试是否合理 |
|||
if not db: |
|||
db = self.__db |
|||
url = f"redis://:{self.__password}@{self.__host}:{self.__port}/{db}" |
|||
pool = ConnectionPool.from_url(url) |
|||
self.__conn = StrictRedis(connection_pool=pool, socket_connect_timeout=timeout) |
|||
|
|||
def __enter__(self): |
|||
self.init_conn() |
|||
return self |
|||
|
|||
def __exit__(self, exc_type, exc_val, exc_tb): |
|||
self.close() |
|||
|
|||
def set(self, key, value, timeout=None): |
|||
""" |
|||
向 redis 中插入一对 key-value |
|||
:param key: 插入的 key |
|||
:param value: 插入的 value |
|||
:param timeout: 插入的 key-value 的过期时间 |
|||
""" |
|||
ex = timeout if timeout else self.__timeout |
|||
self.__conn.set(name=key, value=value, ex=ex) |
|||
|
|||
def get(self, key): |
|||
""" |
|||
从 redis 中获取某个 key 的value |
|||
:param key: 要获取的 key |
|||
:return: redis 对应 key 的 value |
|||
""" |
|||
result = self.__conn.get(key) |
|||
out = str(result, encoding="utf-8") if result else None |
|||
return out |
|||
|
|||
def delete(self, *key): |
|||
""" |
|||
删除某对或者多对 key-value |
|||
:param key: 要删除的 key |
|||
""" |
|||
self.__conn.delete(*key) |
|||
|
|||
def ttl(self, key): |
|||
""" |
|||
获取某个 key 当前剩余的过期时间 |
|||
:param key: 要获取的 key |
|||
:return: 过期时间 |
|||
""" |
|||
result = self.__conn.ttl(key) |
|||
out = int(result) if result else 0 |
|||
return out |
|||
|
|||
def lpush(self, key: str, *value: any) -> None: |
|||
""" |
|||
左边入队操作 |
|||
|
|||
@param key: 插入的key |
|||
@param value: 入队的值 |
|||
""" |
|||
|
|||
self.__conn.lpush(key, *value) |
|||
|
|||
def rpush(self, key: str, *value: any) -> None: |
|||
""" |
|||
右边入队操作 |
|||
|
|||
@param key: 插入的key |
|||
@param value: 入队的值 |
|||
""" |
|||
if value: |
|||
self.__conn.rpush(key, *value) |
|||
|
|||
def lpop(self, key: str, count: int = 1) -> object: |
|||
""" |
|||
左边出队操作 |
|||
|
|||
@param key: 要出队的key |
|||
@param count: 出队数量 |
|||
@return value: 出队元素 |
|||
""" |
|||
|
|||
result = self.__conn.lpop(key, count) |
|||
value = [v.decode() for v in result] if result else [] |
|||
return value |
|||
|
|||
def rpop(self, key: str, count: int = 1) -> object: |
|||
""" |
|||
右边出队操作 |
|||
|
|||
@param key: 要出队的key |
|||
@param count: 出队数量 |
|||
@return value: 出队元素 |
|||
""" |
|||
|
|||
value = [v.decode() for v in self.__conn.rpop(key, count)] |
|||
return value |
|||
|
|||
def rpop_lpush(self, key: str, backup_key: str) -> object: |
|||
""" |
|||
右边出队左边入队操作 |
|||
|
|||
@param key: 要出队的key |
|||
@param backup_key: 备份的key |
|||
@return value: 出队元素 |
|||
""" |
|||
|
|||
value = self.__conn.rpoplpush(key, backup_key) |
|||
return value |
|||
|
|||
def llen(self, key: str) -> int: |
|||
""" |
|||
获取指定key队列的数量 |
|||
|
|||
@param key: 要查询的key |
|||
@return count: 指定key队列的数量 |
|||
""" |
|||
|
|||
return self.__conn.llen(key) |
|||
|
|||
def lrange(self, key: str, start: int = 0, end: int = -1) -> list: |
|||
""" |
|||
获取指定key在指定下标范围内的队列值 |
|||
|
|||
@param key: 要查询的key |
|||
@param start: 队列起始下标 |
|||
@param end: 队列终止下标 |
|||
@return value: 指定下标范围内的队列值 |
|||
""" |
|||
|
|||
value = [v.decode() for v in self.__conn.lrange(key, start, end)] |
|||
return value |
|||
|
|||
def ltrim(self, key: str, start: int = 1, end: int = 0) -> bool: |
|||
""" |
|||
删除指定key指定范围内的数据 |
|||
|
|||
@param key: 要删除的key |
|||
@param start: 队列起始下标 |
|||
@param end: 队列终止下标 |
|||
@return bool: 删除成功或失败 |
|||
""" |
|||
|
|||
return self.__conn.ltrim(key, start, end) |
|||
|
|||
def lrem(self, key: str, value: str, count: int = 0) -> int: |
|||
""" |
|||
删除指定key指定元素 |
|||
|
|||
@param key: 要删除的key |
|||
@param value: 队列起始下标 |
|||
@param count: 删除数量 |
|||
如果大于0,则从左往右删除队列中指定数量的指定元素 |
|||
如果等于0,则删除队列中全量指定元素 |
|||
如果小于0,则从右往左删除队列中指定数量的指定元素 |
|||
@return num: 删除成功的数量 |
|||
""" |
|||
|
|||
return self.__conn.lrem(key, count, value) |
|||
|
|||
def zadd(self, key, values): |
|||
""" |
|||
向 redis 中插入有序集合, |
|||
:param key: 插入的 key |
|||
:param values: 这里的 value 是一个 dict, |
|||
有序集合是根据元素的 score 来进行排序的, |
|||
dict 的 key 是有序集合中的元素, |
|||
dict 的 value 则是元素的 score |
|||
:return: |
|||
""" |
|||
self.__conn.zadd(key, values) |
|||
|
|||
def get_zmenber(self, key, mins, maxs): |
|||
""" |
|||
获取有序集合中排名在某个范围的元素 |
|||
当前该方法是为了 IP 黑名单设计的,后续可以根据其他需求重构 |
|||
:param key: 要获取的有序集合的 key |
|||
:param mins: score 下限 |
|||
:param maxs: score 上限 |
|||
:return: 该有序集合中的某个 score 范围内的元素集合 |
|||
""" |
|||
inner_mins = mins - 60 * 60 * 24 * 7 |
|||
self.__conn.zremrangebyscore(key, inner_mins, mins - 1) |
|||
result = self.__conn.zrangebyscore(key, mins, maxs) |
|||
out = {i.decode(encoding="utf-8") for i in result} |
|||
return out |
|||
|
|||
def exists(self, key): |
|||
return self.__conn.exists(key) |
|||
|
|||
def expire(self, key, timeout): |
|||
self.__conn.expire(key, timeout) |
|||
|
|||
def json_set(self, key, path, *json_data, timeout=None, flag=None): |
|||
""" |
|||
该方法必须要搭配 RedisJSON 模块使用 |
|||
该方法是向 redis 中添加或更新 json 数据 |
|||
当前该方法是为了 log 日志设计 |
|||
:param key: json 的 key |
|||
:param path: json 的节点,是个字符串类型 |
|||
:param json_data: dict 类型,是要添加进 redis 的值 |
|||
:param timeout: key 的过期时间 |
|||
:param flag: |
|||
:return: |
|||
""" |
|||
try: |
|||
if flag: |
|||
self.__conn.json().set(key, ".", {path[1:]: json_data[0]}) |
|||
return |
|||
self.__conn.json().arrappend(key, path, *json_data) |
|||
except ResponseError: |
|||
try: |
|||
self.__conn.json().set(key, path, [json_data[0]]) |
|||
except ResponseError: |
|||
self.__conn.json().set(key, ".", {path[1:]: [json_data[0]]}) |
|||
if len(json_data) >= 2: |
|||
self.json_set(key, path, *json_data[1:]) |
|||
finally: |
|||
if timeout: |
|||
self.__conn.expire(key, timeout) |
|||
|
|||
def json_get(self, key, *path): |
|||
try: |
|||
objkeys = self.__conn.json().objkeys(key) |
|||
except Exception as _: |
|||
objkeys = [] |
|||
if not objkeys: |
|||
objkeys = [] |
|||
path = set(path) & set(objkeys) |
|||
result = self.__conn.json().get(key, *path) |
|||
return result |
|||
|
|||
def close(self): |
|||
""" |
|||
关闭 redis 连接 |
|||
:return: |
|||
""" |
|||
if self.__conn: |
|||
self.__conn.close() |
|||
|
|||
def __del__(self): |
|||
self.close() |
|||
|
|||
def pub(self, channel, message, **kwargs): |
|||
self.__conn.publish(channel, message, **kwargs) |
|||
|
|||
def pubsub(self, **kwargs): |
|||
return self.__conn.pubsub(**kwargs) |
Loading…
Reference in new issue