独角鲸同步合作方公司数据项目
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1017 lines
36 KiB

10 months ago
import calendar
import hashlib
import json
import logging
import math
import os
import random
import re
import time
import traceback
from binascii import b2a_hex, a2b_hex
from datetime import timedelta, datetime, date
from decimal import Decimal, ROUND_HALF_UP
from typing import List, Any
import unicodedata
from Crypto.Cipher import AES
from dateutil.relativedelta import relativedelta
from django.conf import settings
from django.core.paginator import InvalidPage
from django.db import connection
from django.db.models import Q
from django.shortcuts import get_object_or_404
from django_filters import rest_framework
from rest_framework import filters
from rest_framework import serializers
from rest_framework.decorators import action
from rest_framework.exceptions import NotFound, Throttled
from rest_framework.generics import ListAPIView
from rest_framework.pagination import PageNumberPagination
from rest_framework.permissions import BasePermission
from rest_framework.permissions import IsAuthenticated
from rest_framework.settings import api_settings
from rest_framework.views import exception_handler
from rest_framework.viewsets import ModelViewSet
from rest_framework_jwt.authentication import JSONWebTokenAuthentication
import ChaCeRndTrans.code
from ChaCeRndTrans.basic import CCAIResponse
from ChaCeRndTrans.code import *
from hashids import Hashids
# Shared loggers used by the helpers in this module.
error_logger = logging.getLogger("error")
info_logger = logging.getLogger("info")

# Hashids codec that turns integer primary keys into opaque strings.
salt = settings.ID_KEY  # salt taken from Django settings
min_length = 16  # minimum length of each encoded id
ChaCeHashids = Hashids(min_length=min_length, salt=salt)
class CustomViewBase(ModelViewSet):
    """
    Base ModelViewSet for the project.

    create/update stamp audit fields via ``req_operate_by_user``, and every
    response is wrapped in ``CCAIResponse``.
    """
    # pagination_class = LargeResultsSetPagination
    # filter_class = ServerFilter
    queryset = ''
    serializer_class = ''
    permission_classes = ()
    filter_fields = ()
    search_fields = ()
    filter_backends = (rest_framework.DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter,)

    def create(self, request, *args, **kwargs):
        """Validate and save a new object; responds with ``"success"``."""
        data = req_operate_by_user(request)
        serializer = self.get_serializer(data=data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        return CCAIResponse(data="success")

    def list(self, request, *args, **kwargs):
        """Return the filtered queryset, paginated when a paginator is configured."""
        queryset = self.filter_queryset(self.get_queryset())
        page = self.paginate_queryset(queryset)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)
        serializer = self.get_serializer(queryset, many=True)
        return CCAIResponse(data=serializer.data)

    def retrieve(self, request, *args, **kwargs):
        """Return one serialized object."""
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return CCAIResponse(data=serializer.data)

    def update(self, request, *args, **kwargs):
        """Validate and save changes to one object; responds with ``"success"``."""
        data = req_operate_by_user(request)
        # partial=True (PATCH): only the supplied fields are updated;
        # partial=False (PUT): the full set of required fields must be supplied.
        partial = kwargs.pop('partial', False)
        instance = self.get_object()
        serializer = self.get_serializer(instance, data=data, partial=partial)
        serializer.is_valid(raise_exception=True)
        self.perform_update(serializer)
        if getattr(instance, '_prefetched_objects_cache', None):
            # If 'prefetch_related' has been applied to a queryset, we need to
            # forcibly invalidate the prefetch cache on the instance.
            instance._prefetched_objects_cache = {}
        return CCAIResponse(data="success")

    def destroy(self, request, *args, **kwargs):
        """Delete one object."""
        instance = self.get_object()
        self.perform_destroy(instance)
        return CCAIResponse(data="delete resource success")

    @action(methods=['delete'], detail=False)
    def multiple_delete(self, request, *args, **kwargs):
        """
        Delete several objects given as ``?ids=1,2,3``.

        All deletions run in one transaction. A malformed id, a missing
        object or a failed object-permission check rolls back the whole
        batch instead of leaving it half-applied. Objects are looked up
        through ``get_queryset()`` so any queryset scoping in subclasses
        applies.
        """
        from django.db import transaction  # local import keeps file-level imports unchanged
        try:
            delete_id = request.query_params.get('ids', None)
            if not delete_id:
                return CCAIResponse("参数不对啊!", NOT_FOUND)
            with transaction.atomic():
                for i in delete_id.split(','):
                    obj = get_object_or_404(self.get_queryset(), pk=int(i))
                    self.check_object_permissions(request, obj)
                    obj.delete()
            return CCAIResponse("批量删除成功", OK)
        except Exception:
            error_logger.error("multiple delete failed: \n%s" % traceback.format_exc())
            return CCAIResponse("批量删除失败", SERVER_ERROR)
class CustomHashViewBase(ModelViewSet):
    """
    ModelViewSet whose primary keys are exposed as Hashids strings.

    ``list`` replaces each item's ``id`` with its encoded form. Detail
    routes decode the ``pk`` URL kwarg and answer "Invalid ID" with status
    ``BAD`` when it does not decode.
    """
    # pagination_class = LargeResultsSetPagination
    # filter_class = ServerFilter
    queryset = ''
    serializer_class = ''
    permission_classes = ()
    filter_fields = ()
    search_fields = ()
    filter_backends = (rest_framework.DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter,)

    def _decode_pk(self, encoded_id):
        """Return the first integer encoded in ``encoded_id``, or None if it does not decode."""
        decoded = ChaCeHashids.decode(encoded_id)
        return decoded[0] if decoded else None

    def _get_decoded_object(self, pk):
        """
        Fetch ``pk`` from ``get_queryset()`` and run the object-level permission checks.

        A missing object raises Http404 (rendered as 404) rather than
        DoesNotExist (an unhandled 500). The lookups here do not go through
        ``get_object()``, so ``check_object_permissions`` must be called
        explicitly.
        """
        instance = get_object_or_404(self.get_queryset(), pk=pk)
        self.check_object_permissions(self.request, instance)
        return instance

    @staticmethod
    def _encode_ids(data):
        """Replace every item's integer ``id`` with its Hashids string, in place; return ``data``."""
        for item in data:
            item['id'] = ChaCeHashids.encode(item['id'])
        return data

    def create(self, request, *args, **kwargs):
        """Validate and save a new object; responds with ``"success"``."""
        data = req_operate_by_user(request)
        serializer = self.get_serializer(data=data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        return CCAIResponse(data="success")

    def list(self, request, *args, **kwargs):
        """Return the filtered queryset with encoded ids, paginated when configured."""
        queryset = self.filter_queryset(self.get_queryset())
        page = self.paginate_queryset(queryset)
        if page is not None:
            data = self._encode_ids(self.get_serializer(page, many=True).data)
            return self.get_paginated_response(data)
        data = self._encode_ids(self.get_serializer(queryset, many=True).data)
        return CCAIResponse(data=data)

    def retrieve(self, request, *args, **kwargs):
        """Return one object addressed by its encoded id."""
        # NOTE(review): unlike list(), the returned payload keeps the raw
        # integer id — confirm clients rely on that before encoding it here.
        pk = self._decode_pk(kwargs.get('pk'))
        if pk is None:
            return CCAIResponse(data="Invalid ID", status=BAD)
        instance = self._get_decoded_object(pk)
        serializer = self.get_serializer(instance)
        return CCAIResponse(data=serializer.data)

    def update(self, request, *args, **kwargs):
        """Validate and save changes to the object addressed by its encoded id."""
        data = req_operate_by_user(request)
        pk = self._decode_pk(kwargs.get('pk'))
        if pk is None:
            return CCAIResponse(data="Invalid ID", status=BAD)
        partial = kwargs.pop('partial', False)
        instance = self._get_decoded_object(pk)
        serializer = self.get_serializer(instance, data=data, partial=partial)
        serializer.is_valid(raise_exception=True)
        self.perform_update(serializer)
        if getattr(instance, '_prefetched_objects_cache', None):
            # Invalidate a prefetch cache populated by prefetch_related.
            instance._prefetched_objects_cache = {}
        return CCAIResponse(data="success")

    def destroy(self, request, *args, **kwargs):
        """Delete the object addressed by its encoded id."""
        pk = self._decode_pk(kwargs.get('pk'))
        if pk is None:
            return CCAIResponse(data="Invalid ID", status=BAD)
        instance = self._get_decoded_object(pk)
        self.perform_destroy(instance)
        return CCAIResponse(data="delete resource success")

    @action(methods=['delete'], detail=False)
    def multiple_delete(self, request, *args, **kwargs):
        """
        Delete several objects given as ``?ids=<hash>,<hash>``.

        All ids are decoded before anything is deleted. The deletions then
        run in one transaction, so a missing object or a failed permission
        check rolls the whole batch back.
        """
        from django.db import transaction  # local import keeps file-level imports unchanged
        try:
            delete_id = request.query_params.get('ids', None)
            if not delete_id:
                return CCAIResponse("参数不对啊!", status=NOT_FOUND)
            pks = []
            for encoded_id in delete_id.split(','):
                pk = self._decode_pk(encoded_id)
                if pk is None:
                    return CCAIResponse("无效的ID: {}".format(encoded_id), status=NOT_FOUND)
                pks.append(pk)
            with transaction.atomic():
                for pk in pks:
                    self._get_decoded_object(pk).delete()
            return CCAIResponse("批量删除成功", status=OK)
        except Exception:
            error_logger.error("multiple delete failed: \n%s" % traceback.format_exc())
            return CCAIResponse("批量删除失败", status=SERVER_ERROR)
def chacerde_exception_handler(exc, context):
    """
    DRF exception handler that wraps DRF's default payload in an envelope.

    The envelope is ``{"code": status, "message": "失败"/"成功", "detail":
    original data}``. Throttled errors also carry ``wait`` (seconds until
    retry). Exceptions DRF does not handle yield ``None``.
    """
    response = exception_handler(exc, context)
    throttled = isinstance(exc, Throttled)
    if not throttled and response is None:
        return None
    envelope = {
        "code": response.status_code,
        "message": "失败" if response.status_code >= 400 else "成功",
        "detail": response.data,
    }
    if throttled:
        envelope['wait'] = exc.wait
    response.data = envelope
    return response
class CommonPagination(PageNumberPagination):
    """
    Page-number pagination whose response is wrapped in ``CCAIResponse``.

    A request for a page beyond the last one returns the last page instead
    of a 404.
    """
    # Default number of rows per page.
    page_size = 10
    # Query parameter that lets the client choose the page size.
    page_size_query_param = settings.MY_PAGE_SIZE_QUERY_PARAM
    # Query parameter carrying the requested page number.
    page_query_param = settings.MY_PAGE_QUERY_PARAM
    # Upper bound for a client-requested page size.
    max_page_size = settings.MY_MAX_PAGE_SIZE

    def get_paginated_response(self, data):
        """Return ``data`` together with the total row count."""
        return CCAIResponse(
            count=self.page.paginator.count,
            data=data)

    def paginate_queryset(self, queryset, request, view=None):
        """
        Return the rows of the requested page, or ``None`` if pagination is off.

        A numeric page number greater than the last page is clamped to the
        last page. A non-numeric page number is passed to
        ``paginator.page()``, which rejects it, so the client gets
        ``NotFound`` rather than an unhandled ``ValueError``.
        """
        page_size = self.get_page_size(request)
        if not page_size:
            return None
        # queryset = queryset.order_by('id')
        paginator = self.django_paginator_class(queryset, page_size)
        page_number = request.query_params.get(self.page_query_param, 1)
        count = paginator.count  # total number of rows
        if page_number in self.last_page_strings:
            page_number = paginator.num_pages
        try:
            requested = int(page_number)
        except (TypeError, ValueError):
            requested = None  # left for paginator.page() to reject below
        if count and requested and requested > paginator.num_pages:
            self.page = paginator.page(paginator.num_pages)  # clamp to the last page
        else:
            try:
                self.page = paginator.page(page_number)
            except InvalidPage as exc:
                msg = self.invalid_page_message.format(
                    page_number=page_number, message=str(exc)
                )
                raise NotFound(msg) from exc
        if paginator.num_pages > 1 and self.template is not None:
            # The browsable API should display pagination controls.
            self.display_page_controls = True
        self.request = request
        return list(self.page)
class RbacPermission(BasePermission):
    """
    Role-based permission check.

    A request is allowed when one of the user's roles grants "admin", when
    the view has no ``perms_map``, or when some ``perms_map`` entry's method
    matches the request's HTTP method (or is "*") and its alias is among the
    user's permission methods.
    """

    @classmethod
    def get_permission_from_role(cls, request):
        """
        Return the permission method names granted by the user's roles.

        If the ``companyMid`` query parameter is present, only roles bound to
        that company or to no company are considered. Returns ``None`` when
        the user is falsy or has no ``roles`` attribute (AttributeError).
        """
        try:
            if not request.user:
                return None
            company_mid = request.query_params.get('companyMid', None)
            roles = request.user.roles
            if company_mid:
                roles = roles.filter(Q(companyMid=company_mid) | Q(companyMid__isnull=True))
            return [item["permissions__method"]
                    for item in roles.values("permissions__method").distinct()]
        except AttributeError:
            return None

    def has_permission(self, request, view):
        """Return True if the user's role permissions allow this request, else False."""
        perms = self.get_permission_from_role(request)
        if not perms:
            return False
        if "admin" in perms:
            return True
        if not hasattr(view, "perms_map"):
            return True
        http_method = request._request.method.lower()
        return any(
            (http_method == method or method == "*") and alias in perms
            for entry in view.perms_map
            for method, alias in entry.items()
        )
class ObjPermission(BasePermission):
    """
    Object-level permission for password-management records.

    Allowed for users whose roles grant "admin" and for the object's owner
    (``obj.uid_id``).
    """

    def has_object_permission(self, request, view, obj):
        perms = RbacPermission.get_permission_from_role(request)
        # get_permission_from_role() may return None; guard so that case
        # falls through to the owner check instead of raising TypeError.
        if perms and "admin" in perms:
            return True
        user_id = getattr(request.user, "id", None)
        # An anonymous user (id None) must not match an object whose uid_id is None.
        return user_id is not None and user_id == obj.uid_id
class TreeSerializer(serializers.Serializer):
    """Serialize a tree node as ``{"id", "label", "pid"}``; ``label`` is read from ``name``."""
    id = serializers.IntegerField()
    label = serializers.CharField(source="name", max_length=20)
    pid = serializers.PrimaryKeyRelatedField(read_only=True)
class TreeAPIView(ListAPIView):
    """
    List view that returns the queryset nested as a tree.

    Each serialized node is attached to the ``children`` list of the node
    whose ``id`` equals its ``pid``. Nodes with a falsy ``pid`` are roots.
    Nodes whose parent is absent from the result are dropped. If a node
    lacks ``id``/``pid``, the flat serialized list is returned instead.
    """
    serializer_class = TreeSerializer
    authentication_classes = (JSONWebTokenAuthentication,)
    permission_classes = (IsAuthenticated,)

    def list(self, request, *args, **kwargs):
        queryset = self.filter_queryset(self.get_queryset())
        # NOTE: ``page`` only decides the response shape below; the tree is
        # built from the full (unpaginated) queryset.
        page = self.paginate_queryset(queryset)
        serializer = self.get_serializer(queryset, many=True)
        try:
            nodes_by_id = {}
            for node in serializer.data:
                nodes_by_id[node["id"]] = node
            roots = []
            for node in nodes_by_id.values():
                parent_id = node["pid"]
                if not parent_id:
                    roots.append(node)
                    continue
                if parent_id not in nodes_by_id:
                    continue  # parent menu missing from the result set
                nodes_by_id[parent_id].setdefault("children", []).append(node)
            results = roots
        except KeyError:
            results = serializer.data
        if page is None:
            return CCAIResponse(results)
        return self.get_paginated_response(results)
def req_operate_by_user(request):
    """
    Return a copy of ``request.data`` stamped with audit fields.

    POST sets ``CreateBy``/``CreateByUid`` and ``UpdateBy``/``UpdateByUid``.
    PUT and PATCH set only the ``UpdateBy*`` fields; PATCH also reaches
    ``update()`` via DRF's partial_update. Other methods get an unmodified
    copy. ``request.data`` itself is never mutated.
    """
    data = request.data.copy()
    method = request.method
    if method == "POST":
        data['CreateBy'] = request.user.name
        data['CreateByUid'] = request.user.id
    if method in ("POST", "PUT", "PATCH"):
        data['UpdateBy'] = request.user.name
        data['UpdateByUid'] = request.user.id
    return data
def sha1_encrypt(data):
    """
    Return the lowercase hex SHA-1 digest of the str ``data`` (UTF-8 encoded).

    Note: despite the name this is a one-way hash, not encryption.
    """
    return hashlib.sha1(data.encode('utf-8')).hexdigest()
def generate_random_str(randomlength=10):
    """
    Return a random string of ``randomlength`` characters.

    Characters are drawn from ASCII letters, digits and ``*@#$%^-~+``.
    ``random.SystemRandom`` (OS entropy) is used so results are not
    predictable from the module-level PRNG.

    NOTE: a later ``generate_random_str`` definition in this module
    overrides this one at import time.
    """
    rng = random.SystemRandom()
    base_str = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789*@#$%^-~+'
    return ''.join(rng.choice(base_str) for _ in range(randomlength))
def is_connection_usable():
    """
    Return True if the raw connection behind Django's ``connection`` answers ``ping()``.

    Any error counts as unusable, including a connection never having been
    opened (``connection.connection`` is None) or a driver whose raw
    connection has no ``ping()``.
    """
    try:
        connection.connection.ping()
    except Exception:
        # Not a bare ``except:`` so KeyboardInterrupt / SystemExit still propagate.
        return False
    return True
def generate_random_str_for_fileName(randomlength=10):
    """
    Return a random string of ``randomlength`` ASCII letters and digits.

    Suitable for file names because it contains no special symbols.
    ``random.SystemRandom`` is used so names are not predictable from the
    module-level PRNG.
    """
    rng = random.SystemRandom()
    base_str = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
    return ''.join(rng.choice(base_str) for _ in range(randomlength))
def generate_random_str_16_system(randomlength=16):
    """
    Return ``randomlength`` random lowercase hex characters (0-9, a-f).

    This is used for AES key/IV material, so it draws from
    ``random.SystemRandom`` (OS entropy). The module-level Mersenne
    Twister is predictable.
    """
    rng = random.SystemRandom()
    hex_chars = '0123456789abcdef'
    return ''.join(rng.choice(hex_chars) for _ in range(randomlength))
def generate_random_pwd(the_length=9):
    """
    Return a random plaintext password of ``the_length`` characters (at least 9).

    The password holds 1-2 symbols from ``!@#$%^&*+``, 3-4 digits and ASCII
    letters for the remainder, in shuffled order. ``random.SystemRandom`` is
    used so passwords are not predictable from the module-level PRNG.
    """
    rng = random.SystemRandom()
    if the_length < 9:
        the_length = 9
    special_str = "!@#$%^&*+"
    nums = "1234567890"
    zi_mu = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"  # letters
    special_count = rng.randint(1, 2)  # number of symbols
    num_count = rng.randint(3, 4)  # number of digits
    chars = [rng.choice(special_str) for _ in range(special_count)]
    chars += [rng.choice(nums) for _ in range(num_count)]
    chars += [rng.choice(zi_mu) for _ in range(the_length - special_count - num_count)]
    rng.shuffle(chars)  # don't leave symbols/digits at fixed positions
    return ''.join(chars)
# X-Forwarded-For (XFF) carries the originating client IP of an HTTP request;
# it is only added when the request passes through an HTTP proxy or load balancer.
def get_ip(request):
    """
    Return the requesting client's IP address, or '' on any error.

    ``X-Real-IP`` is used when present. Otherwise DRF's ``NUM_PROXIES``
    setting selects an entry from ``X-Forwarded-For``, falling back to
    ``REMOTE_ADDR``. With ``NUM_PROXIES`` unset, the whole XFF value (with
    whitespace removed) is returned.

    NOTE(review): both headers can be set by the client unless a trusted
    proxy overwrites them; don't use this for access control.
    """
    try:
        meta = request.META
        real_ip = meta.get('HTTP_X_REAL_IP')
        if real_ip:
            return real_ip
        forwarded = meta.get('HTTP_X_FORWARDED_FOR')
        peer = meta.get('REMOTE_ADDR')
        proxy_count = api_settings.NUM_PROXIES
        if proxy_count is None:
            return ''.join(forwarded.split()) if forwarded else peer
        if proxy_count == 0 or forwarded is None:
            return peer
        hops = forwarded.split(',')
        return hops[-min(proxy_count, len(hops))].strip()
    except Exception:
        return ''
# Persist an operation-history record to the database.
def create_operation_history_log(request, info, TheModel):
    """
    Create one operation-history row through ``TheModel.objects.create``.

    :param request: current request, or a falsy value for background tasks
        (recorded with username ``"task"`` and no IP).
    :param info: dict providing ``"des"`` and ``"detail"``.
    :param TheModel: model class whose manager accepts the ``ip``/``des``/
        ``user_id``/``username``/``detail`` keyword arguments used below.

    Failures are logged and never raised.
    """
    try:
        if not request:
            TheModel.objects.create(des=info["des"], username="task", detail=info["detail"])
            return
        ip = get_ip(request)
        if request.user.id:
            TheModel.objects.create(ip=ip, des=info["des"], user_id=request.user.id,
                                    username=request.user.username, detail=info["detail"])
        else:
            TheModel.objects.create(ip=ip, des=info["des"],
                                    username="智云用户", detail=info["detail"])
    except Exception:
        # request may be None (task context) or lack a user; reading
        # request.user.id directly here would raise from inside the handler.
        user_id = getattr(getattr(request, "user", None), "id", None)
        logging.getLogger("error").error(
            "user: %s, create history failed \n%s" % (user_id, traceback.format_exc()))
# Check password complexity.
def check_password_complexity(password):
    """
    Validate ``password`` against the password policy.

    Rules: 8-25 characters; only ASCII letters, digits and the symbols
    ``.!`/?@#$%^&*()+-``; at least 2 symbols, 1 digit and 1 letter.

    :return: ``{"flag": bool, "message": str}``. ``message`` is
        ``"success"`` when the password passes, otherwise the first rule
        that failed.
    """
    # One entry per symbol, so "/" and "?" are both accepted as the error message promises.
    must_str = set(".!`/?@#$%^&*()+-")  # at least two of these
    must_num = set("0123456789")  # at least one
    must_abc = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")  # at least one
    if len(password) < 8:
        return {"flag": False, "message": "密码长度至少8个字符及以上"}
    elif len(password) > 25:
        return {"flag": False, "message": "密码长度至多25个字符及以下"}
    must_str_nums = 0
    must_num_nums = 0
    must_abc_nums = 0
    for each in password:
        if each in must_str:
            must_str_nums += 1
        elif each in must_num:
            must_num_nums += 1
        elif each in must_abc:
            must_abc_nums += 1
        else:
            return {"flag": False, "message": "密码不允许存在其他不规范字符"}
    if must_str_nums < 2:
        return {"flag": False, "message": "密码必须包含2个英文特殊字符及以上如:.!`/?@#$%^&*()+-"}
    if must_num_nums < 1:
        return {"flag": False, "message": "密码必须包含数字"}
    if must_abc_nums < 1:
        return {"flag": False, "message": "密码必须包含字母"}
    return {"flag": True, "message": "success"}
# Record (create or update) a punishment entry.
def record_punishment(data, request, PunishmentInfoModel):
    """
    Create or update ``request.user``'s punishment record for paper ``data["new_paper_id"]``.

    If a record exists, only ``warning_count`` and ``is_punishment`` are
    updated and saved. Otherwise ``data["exam_name"]`` is suffixed in place
    (on the caller's dict) with this user's attempt number for
    ``data["new_exam_id"]``, and a new row is created.
    """
    user = request.user
    existing = PunishmentInfoModel.objects.filter(user_id=user.id, new_paper_id=data["new_paper_id"]).first()
    if existing:
        existing.warning_count = data["warning_count"]
        existing.is_punishment = data["is_punishment"]
        existing.save()
        return
    batch = request.user.batch.name if request.user.batch else None
    attempts = PunishmentInfoModel.objects.filter(user_id=user.id, new_exam_id=data["new_exam_id"]).count()
    data["exam_name"] = data["exam_name"] + "-第" + str(attempts + 1) + "次测试"
    PunishmentInfoModel.objects.create(
        user_id=user.id,
        username=user.username,
        exam_id=data["exam_id"],
        new_exam_id=data["new_exam_id"],
        exam_name=data["exam_name"],
        paper_id=data["paper_id"],
        new_paper_id=data["new_paper_id"],
        warning_count=data["warning_count"],
        max_warning_count=data["max_warning_count"],
        batch=batch,
        is_punishment=data["is_punishment"],
        name=user.name,
    )
# AES encryption helper.
class AESEnDECryptRelated:
    """
    AES-CBC helper that embeds its random key and IV in the output string.

    ``start_encrypt`` generates a 16-char key and a 16-char IV. It encrypts
    the payload (JSON-dumped unless it is already a str) and lays the
    result out as ``iv[:8] + key[8:] + hex_ciphertext + key[:8] + iv[8:] +
    "=="``. ``decrypt`` reverses that layout.

    NOTE(review): because key and IV travel with the ciphertext, this
    provides obfuscation, not confidentiality.
    """

    def __init__(self):
        self.model = AES.MODE_CBC  # cipher mode (attribute name kept as-is)

    def add_to_16(self, text):
        """
        PKCS#7-pad ``text`` to a multiple of 16 UTF-8 bytes; return the bytes.

        pycryptodome does not pad (unlike the JS library). The pad length is
        computed on the encoded bytes because non-ASCII characters are
        multi-byte.
        """
        byte_len = len(text.encode('utf-8'))
        pad = 16 - byte_len % 16
        return (text + chr(pad) * pad).encode('utf-8')

    def encrypt(self, text, key, iv):
        """Pad ``text``, AES-CBC encrypt it with ``key``/``iv`` and return lowercase hex."""
        padded = self.add_to_16(text)
        cipher = AES.new(key, self.model, iv)
        return b2a_hex(cipher.encrypt(padded)).decode('utf-8')

    def decrypt(self, text):
        """
        Decrypt a string produced by ``start_encrypt`` and return the plaintext str.

        The PKCS#7 padding is NOT removed. The last character's code point
        is the number of pad characters the caller must strip, e.g. before
        ``json.loads``.
        """
        ciphertext, key_str, iv_str = self.get_decrypt_info(text)
        raw = a2b_hex(ciphertext)
        iv = iv_str.encode('utf-8')
        key = key_str.encode('utf-8')
        cipher = AES.new(key, self.model, iv)
        return cipher.decrypt(raw).decode('utf-8')

    def get_decrypt_info(self, text):
        """Split a ``combination_new_ciphertext`` string into ``(hex_ciphertext, key, iv)``."""
        body = text[:-2]  # drop the trailing "=="
        iv = body[:8] + body[-8:]
        key = body[-16:-8] + body[8:16]
        return body[16:-16], key, iv

    def combination_new_ciphertext(self, ciphertext, iv_str, key_str):
        """Return ``iv[:8] + key[8:] + ciphertext + key[:8] + iv[8:] + "=="``."""
        return "".join((iv_str[:8], key_str[8:], ciphertext, key_str[:8], iv_str[8:], "=="))

    def start_encrypt(self, text):
        """Entry point: encrypt ``text`` under a fresh random key/IV and return the combined string."""
        key_str = generate_random_str_16_system(16)
        key = key_str.encode('utf-8')
        iv_str = generate_random_str_16_system(16)
        iv = iv_str.encode('utf-8')
        payload = text if isinstance(text, str) else json.dumps(text)
        cipher_hex = self.encrypt(payload, key, iv)
        return self.combination_new_ciphertext(cipher_hex, iv_str, key_str)
# Validate a bank (debit) card number.
def is_valid_debit_card(card_number):
    """
    Return True if ``card_number`` is a string of 16-19 ASCII digits.

    Only the format is checked (no Luhn checksum). ``str.isdigit`` alone
    also accepts full-width and superscript digits, so ASCII is required
    explicitly.
    """
    return (card_number.isascii() and card_number.isdigit()
            and 16 <= len(card_number) <= 19)
def generate_random_str(randomlength=16):
    """
    Return a random string of ``randomlength`` lowercase ASCII letters and digits.

    ``random.SystemRandom`` (OS entropy) is used so results are not
    predictable from the module-level PRNG.
    """
    rng = random.SystemRandom()
    base_str = 'abcdefghijklmnopqrstuvwxyz0123456789'
    return ''.join(rng.choice(base_str) for _ in range(randomlength))
def sha1_encrypt(data):
    """
    Hash the str ``data`` (UTF-8 encoded) with SHA-1; return the hex digest.

    Note: this is a one-way hash, not reversible encryption.
    """
    digest = hashlib.sha1()
    digest.update(data.encode('utf-8'))
    return digest.hexdigest()
# Check whether a string consists only of Chinese characters.
def is_all_chinese(s):
    """
    Return True if ``s`` is non-empty and every character is in U+4E00..U+9FA5.

    ``re.fullmatch`` is used because ``$`` under ``re.match`` also matches
    just before a trailing newline, which let such strings pass.
    """
    return re.fullmatch(r'[\u4e00-\u9fa5]+', s) is not None
# Validate a username: only Chinese characters, ASCII letters and digits are allowed.
def is_valid_username(username):
    """
    Return True if ``username`` is non-empty and consists only of
    U+4E00..U+9FA5 characters, ASCII letters and ASCII digits.

    ``re.fullmatch`` rejects a trailing newline, which ``^...$`` with
    ``re.match`` would accept.
    """
    return re.fullmatch(r'[\u4e00-\u9fa5A-Za-z0-9]+', username) is not None
# Compute a sub-account's distribution ratio.
def calculate_sub_ratio(main_ratio, sub_ratio):
    """
    Return ``main_ratio * sub_ratio / 100`` as a Decimal, rounded half-up to 2 places.

    Arguments may be int, str, Decimal or float. Floats go through ``str()``
    so that e.g. 2.01 means 2.01 and not its binary approximation
    2.00999..., which would round the wrong way.
    """
    def to_decimal(value):
        return Decimal(str(value)) if isinstance(value, float) else Decimal(value)

    result = to_decimal(main_ratio) * to_decimal(sub_ratio) / Decimal(100)
    return result.quantize(Decimal('0.00'), rounding=ROUND_HALF_UP)
# def asyncDeleteFile(request, fileName):
# try:
# time.sleep(1) # 等待文件被读取完毕
# if os.path.exists(fileName): # 如果文件存在
# # 删除文件,可使用以下两种方法。
# os.remove(fileName)
# except Exception as e:
# logging.getLogger('error').error(
# "user: %s, delete FileName: %s file failed: \n%s" % (
# request.user.id, fileName, traceback.format_exc()))
def asyncDeleteFile(request, fileName, max_attempts=10, retry_delay=1):
    """
    Delete ``fileName``, then remove its parent directory if that is now empty.

    Deletion is retried up to ``max_attempts`` times, ``retry_delay``
    seconds apart, while it fails with PermissionError (e.g. the file is
    still held open). Other errors are logged once. Nothing is raised.

    :param request: used only for the error log's user id; may be None.
    """
    info_log = logging.getLogger('info')
    error_log = logging.getLogger('error')
    try:
        for attempt in range(1, max_attempts + 1):
            try:
                os.remove(fileName)
                break  # deleted
            except PermissionError as e:
                error_log.error(f"文件删除失败:{e}")
                if attempt >= max_attempts:
                    error_log.error("已达到最大尝试次数,无法删除文件")
                    return
                time.sleep(retry_delay)  # wait, then retry the delete
        else:
            return  # max_attempts < 1: nothing was attempted
        info_log.info('delete success')
        parent_dir = os.path.dirname(fileName)
        if parent_dir and not os.listdir(parent_dir):
            os.rmdir(parent_dir)
    except Exception:
        # request may be None (e.g. when run outside a request context).
        user_id = getattr(getattr(request, 'user', None), 'id', None)
        error_log.error(
            "user: %s, delete FileName: %s file failed: \n%s" % (
                user_id, fileName, traceback.format_exc()))
class MyCustomError(Exception):
    """Project-specific exception that keeps its text in ``message``."""

    def __init__(self, message):
        super().__init__(message)
        self.message = message
def digit_to_chinese(num):
    """
    Convert a numeric amount to a Chinese amount string.

    The integer part is split into groups of 4 digits from the right. Each
    non-zero digit is emitted with its positional unit from ``units``, and
    each group is suffixed with its entry in ``digits``. Two decimal places
    (rounded) are appended using ``chinese_digits`` / ``chinese_units``.

    NOTE(review): in this copy of the file every lookup-table literal below
    is an empty string except '亿', while the comments refer to 万/亿 and
    元/角/分. The tables look like they lost characters in an encoding
    round-trip — verify against version control before relying on the output.
    NOTE(review): zero digits inside a group and all-zero groups are
    skipped without emitting any "zero" character — confirm that is intended.
    NOTE(review): a negative input fails at ``int('-')`` and returns 0; more
    than four digit groups overflows ``digits`` and also returns 0.

    :param num: int or float amount.
    :return: the converted string; on any exception the error is logged
        and the int ``0`` (not a string) is returned.
    """
    try:
        units = ['', '', '', '']
        digits = ['', '', '亿', '']
        chinese_digits = ['', '', '', '', '', '', '', '', '', '']
        chinese_units = ['', '', '']
        result = ''
        num_int = int(num)
        num_decimal = round((num - num_int) * 100)  # fractional part in hundredths, rounded
        num_str = str(num_int)
        num_str_len = len(num_str)
        digit_group_count = (num_str_len + 3) // 4  # number of 4-digit groups
        for i in range(digit_group_count):
            # i-th group of up to 4 digits, counted from the right
            group_str = num_str[max(0, num_str_len - (i + 1) * 4):num_str_len - i * 4]
            group_int = int(group_str)
            if group_int == 0:
                continue
            # Convert each non-zero digit of the group together with its positional unit
            group_result = ''
            for j in range(len(group_str)):
                digit = int(group_str[j])
                if digit != 0:
                    group_result += chinese_digits[digit] + units[len(group_str) - j - 1]
            # Append the group-level unit (intended: 万, 亿, ...)
            group_result += digits[i]
            result = group_result + result
        # Append the yuan / jiao / fen parts (intended: 元, 角, 分)
        if result != '':
            result += ''
        if num_decimal > 0:
            result += chinese_digits[num_decimal // 10] + chinese_units[1] + chinese_digits[num_decimal % 10] + \
                      chinese_units[2]
        else:
            result += ''
        return result
    except Exception as e:
        logging.getLogger('error').error(
            "digit: %s changeTo chinese failed: \n%s" % (num, traceback.format_exc()))
        return 0
def get_first_and_last_day(year_month):
    """
    Return ``(first_day, last_day)`` of a month as ``date`` objects.

    ``year_month`` is "YYYY-MM" (anything after the month, e.g. "-DD", is ignored).
    """
    year, month = map(int, year_month.split('-')[:2])
    month_start = datetime(year, month, 1)
    next_month_start = datetime(year + 1, 1, 1) if month == 12 else datetime(year, month + 1, 1)
    month_end = next_month_start - timedelta(days=1)
    return month_start.date(), month_end.date()
def get_all_months(start_date, end_date):
    """
    Return "YYYY-MM" for every month from ``start_date``'s month up to ``end_date``.

    A month is included while its first day is not after ``end_date``. The
    result is ``[]`` if ``end_date`` falls before the month of
    ``start_date``.

    Each argument may be a ``date``, a ``datetime`` or a "YYYY-MM-DD" string.
    Both are normalised to ``date``, so mixed kinds compare correctly.
    """
    def to_date(value):
        if isinstance(value, datetime):  # check before date: datetime subclasses date
            return value.date()
        if isinstance(value, date):
            return value
        return datetime.strptime(value, '%Y-%m-%d').date()

    end = to_date(end_date)
    current = to_date(start_date).replace(day=1)  # first day of the starting month
    months = []
    while current <= end:
        months.append(current.strftime('%Y-%m'))
        # advance to the first day of the next month
        current = date(current.year + current.month // 12, current.month % 12 + 1, 1)
    return months
def calculate_hours(AmStart, AmEnd, PmStart, PmEnd, Separate=False):  # Separate: return AM/PM separately
    """
    Compute the company's working hours from "HH:MM" start/end times.

    Each half-day span is rounded up to a whole number of half hours.
    Returns the total in hours, or ``(am_hours, pm_hours)`` when
    ``Separate`` is true.

    NOTE(review): ``timedelta.seconds`` is used, so an end earlier than its
    start wraps around (e.g. 12:00 -> 11:00 counts as 23 hours).
    """
    fmt = '%H:%M'

    def half_hour_slots(start, end):
        # number of 30-minute slots needed to cover start..end (ceiling)
        began = datetime.strptime(start, fmt)
        ended = datetime.strptime(end, fmt)
        seconds = (ended - began).seconds
        return -(-seconds // 1800)

    am_slots = half_hour_slots(AmStart, AmEnd)
    pm_slots = half_hour_slots(PmStart, PmEnd)
    if not Separate:
        return (am_slots + pm_slots) / 2
    return am_slots / 2, pm_slots / 2
def get_month_days(year_month):
    """Return how many days the month "YYYY-MM" has."""
    year, month = (int(part) for part in year_month.split('-'))
    _, day_count = calendar.monthrange(year, month)
    return day_count
class ErrorTable:
    """
    Error-report table: rows of input data, each with an error message.

    ``get_table()`` renders ``{"name": [*columns, "ErrorMessage"], "data":
    [row dicts]}``.
    """

    def __init__(self, column_names: List[str]):
        if 0 == len(column_names):
            raise ValueError('初始化列名必填')
        # list() accepts any sequence (e.g. a tuple) and avoids aliasing the caller's list.
        self.column_names = list(column_names) + ['ErrorMessage']
        self.data = []

    def has_data(self) -> bool:
        """Return True if at least one row has been added."""
        return bool(self.data)

    def add_one_row(self, row_data: List[Any], error_message: str = None):
        """
        Append one row; ``row_data`` itself is not modified.

        :param row_data: values for the data columns, in column order.
        :param error_message: text for the ErrorMessage column ("" when None).
        :raises MyCustomError: if ``row_data`` does not have one value per data column.
        """
        if len(row_data) != len(self.column_names) - 1:
            raise MyCustomError("添加错误信息行失败,表格列数与输入列数不符")
        values = list(row_data) + ["" if error_message is None else error_message]
        self.data.append(dict(zip(self.column_names, values)))

    def get_table(self) -> dict:
        """Return ``{'name': column names incl. ErrorMessage, 'data': row dicts}``."""
        return {'name': list(self.column_names), 'data': self.data}
def common_update_sql(item, table_name):
    """
    Build the ``update <table> set col = %s, ...`` prefix (everything before WHERE).

    :param item: mapping whose keys are the column names; the values are
        meant to be passed separately as query parameters for the ``%s``
        placeholders.
    :param table_name: table to update.

    NOTE(review): ``table_name`` and the keys are interpolated verbatim
    into the SQL, so they must never come from user input.
    """
    assignments = ','.join(f" {column} = %s" for column in item)
    return f"update {table_name} set " + assignments
def remove_invisible_characters(value):
    """
    Remove all whitespace and zero-width characters from ``value``.

    Whitespace here is anything ``re`` treats as whitespace, including
    ordinary spaces, tabs and newlines. Zero-width space/non-joiner/joiner
    (U+200B-U+200D), word joiner (U+2060) and BOM (U+FEFF) are not
    whitespace to ``re``, so they are listed explicitly.
    """
    return re.sub(r'[\s\u200b-\u200d\u2060\ufeff]+', '', value)
def transform_characters(value):
    """
    Convert ``value`` to plain ASCII.

    NFKD normalisation maps full-width forms to their half-width
    equivalents and splits accented letters into base letter + combining
    mark. Every remaining non-ASCII code point is then dropped.

    NOTE(review): this also deletes CJK and any other non-ASCII text entirely.
    """
    decomposed = unicodedata.normalize('NFKD', value)
    return decomposed.encode('ascii', 'ignore').decode('ascii')
def count_1(n: int) -> int:
    """
    Count check-ins: return the number of 1 bits in ``n``.

    @param n: non-negative integer whose binary digits encode check-in records.
    @return: number of 1 bits in ``n``.
    @raise ValueError: if ``n`` is negative (a negative int never reaches 0
        under ``>>``, so bit-by-bit counting would loop forever).
    """
    if n < 0:
        raise ValueError("count_1() requires a non-negative integer, got %r" % (n,))
    return bin(n).count("1")
def get_month_last(_date: date, length: int) -> list[str]:
    """
    Return the ``length`` months before ``_date``'s month (excluding it), newest first.

    @param _date: reference date.
    @param length: number of months wanted.
    @return: list of "YYYY-MM" strings.
    """
    month_index = _date.year * 12 + _date.month - 1  # zero-based absolute month number
    result = []
    for back in range(1, length + 1):
        year, month0 = divmod(month_index - back, 12)
        result.append(f"{year}-{month0 + 1:02d}")
    return result
def round_to_highest_digit(number):
    """
    Keep only the most significant digit of ``number``, truncating toward zero.

    Examples: 1999 -> 1000, -1999 -> -1000, 0.35 -> 0.3.

    @param number: int or float.
    @return: 0 for 0; an int when ``abs(number) >= 1``, otherwise a float.
    """
    if number == 0:
        return 0
    # floor, not int(): for |number| < 1 the log is negative and int() would
    # truncate it toward zero (e.g. -0.46 -> 0), losing the leading digit.
    digits = math.floor(math.log10(abs(number)))
    if digits >= 0:
        return int(number / (10 ** digits)) * (10 ** digits)
    # Scale up by an integer power of ten instead of dividing by a fractional
    # one, which is inexact in binary floating point.
    scale = 10 ** -digits
    return int(number * scale) / scale
def mul_split(text):
    """
    Split ``text`` on any character of ``split_string``.

    Returns the stripped, non-empty pieces. Returns None (after logging) if
    splitting fails, e.g. when ``text`` is not a str.
    """
    try:
        split_string = ' ,.; ,。;'  # separator characters
        pieces = re.split(r'[' + re.escape(split_string) + ']', text)
        # Strip before filtering so whitespace-only pieces (e.g. a lone tab) are dropped too.
        return [piece for piece in (p.strip() for p in pieces) if piece]
    except Exception:
        error_logger.error("text: %s, mul_split failed: %s" % (text, traceback.format_exc()))
        return
def time_str_to_int(time_str):
    """Convert an "MM:SS" string into a total number of seconds."""
    minutes, seconds = (int(part) for part in time_str.split(':'))
    return seconds + 60 * minutes
def time_int_to_str(total_seconds):
    """Format a whole number of seconds as "MM:SS" (minutes are not capped at 59)."""
    if total_seconds == 0:
        # Short-circuit also covers 0.0, which the integer format spec below would reject.
        return '00:00'
    minutes, seconds = divmod(total_seconds, 60)
    return "{:02d}:{:02d}".format(minutes, seconds)