Browse Source

added class view decortors

dev
Jingxun 8 months ago
parent
commit
8e7917d2e3
  1. 87
      backend/decorators/cbv_decorator.py

87
backend/decorators/cbv_decorator.py

@ -0,0 +1,87 @@
# encoding: utf-8
"""
@author: Qiancj
@contact: qiancj@risenenergy.com
@file: cbv
@create-time: 2023-09-25 15:42
@description: The new python script
"""
import inspect
from typing import get_type_hints
from fastapi import APIRouter, Depends
from pydantic.typing import is_classvar
from starlette.routing import Route, WebSocketRoute
CBV_CLASS_KEY = "__cbv_class__"
def _update_cbv_route_endpoint_signature(cls, route):
old_endpoint = route.endpoint
old_signature = inspect.signature(old_endpoint)
old_parameters = list(old_signature.parameters.values())
old_first_parameter = old_parameters[0]
new_first_parameter = old_first_parameter.replace(default=Depends(cls))
new_parameters = [new_first_parameter] + [
parameter.replace(kind=inspect.Parameter.KEYWORD_ONLY) for parameter in old_parameters[1:]
]
new_signature = old_signature.replace(parameters=new_parameters)
setattr(route.endpoint, "__signature__", new_signature)
def _init_cbv(cls):
if getattr(cls, CBV_CLASS_KEY, False): # pragma: no cover
return # Already initialized
old_init = cls.__init__
old_signature = inspect.signature(old_init)
old_parameters = list(old_signature.parameters.values())[1:] # drop `self` parameter
new_parameters = [
x for x in old_parameters if x.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
]
dependency_names = []
for name, hint in get_type_hints(cls).items():
if is_classvar(hint):
continue
parameter_kwargs = {"default": getattr(cls, name, Ellipsis)}
dependency_names.append(name)
new_parameters.append(
inspect.Parameter(name=name, kind=inspect.Parameter.KEYWORD_ONLY, annotation=hint, **parameter_kwargs)
)
new_signature = old_signature.replace(parameters=new_parameters)
def new_init(self, *args, **kwargs):
for dep_name in dependency_names:
dep_value = kwargs.pop(dep_name)
setattr(self, dep_name, dep_value)
old_init(self, *args, **kwargs)
setattr(cls, "__signature__", new_signature)
setattr(cls, "__init__", new_init)
setattr(cls, CBV_CLASS_KEY, True)
def _cbv(router, cls):
_init_cbv(cls)
cbv_router = APIRouter()
function_members = inspect.getmembers(cls, inspect.isfunction)
functions_set = set(func for _, func in function_members)
cbv_routes = [
route
for route in router.routes
if isinstance(route, (Route, WebSocketRoute)) and route.endpoint in functions_set
]
for route in cbv_routes:
router.routes.remove(route)
_update_cbv_route_endpoint_signature(cls, route)
cbv_router.routes.append(route)
router.include_router(cbv_router)
return cls
def cbv(router):
def decorator(cls):
return _cbv(router, cls)
return decorator
Loading…
Cancel
Save