import calendar
import hashlib
import json
import logging
import math
import os
import random
import re
import secrets
import time
import traceback
import unicodedata
from binascii import b2a_hex, a2b_hex
from datetime import timedelta, datetime, date
from decimal import Decimal, ROUND_HALF_UP
from typing import List, Any

from Crypto.Cipher import AES
from dateutil.relativedelta import relativedelta
from django.conf import settings
from django.core.paginator import InvalidPage
from django.db import connection
from django.db.models import Q
from django.shortcuts import get_object_or_404
from django_filters import rest_framework
from rest_framework import filters
from rest_framework import serializers
from rest_framework.decorators import action
from rest_framework.exceptions import NotFound, Throttled
from rest_framework.generics import ListAPIView
from rest_framework.pagination import PageNumberPagination
from rest_framework.permissions import BasePermission
from rest_framework.permissions import IsAuthenticated
from rest_framework.settings import api_settings
from rest_framework.views import exception_handler
from rest_framework.viewsets import ModelViewSet
from rest_framework_jwt.authentication import JSONWebTokenAuthentication

import ChaCeRndTrans.code
from ChaCeRndTrans.basic import CCAIResponse
from ChaCeRndTrans.code import *
from hashids import Hashids

error_logger = logging.getLogger("error")
info_logger = logging.getLogger("info")

# Hashids encoder used to obfuscate integer primary keys exposed by CustomHashViewBase.
salt = settings.ID_KEY  # salt
min_length = 16  # minimum length of an encoded id
ChaCeHashids = Hashids(salt=salt, min_length=min_length)


class CustomViewBase(ModelViewSet):
    """ModelViewSet whose responses are wrapped in ``CCAIResponse``.

    Subclasses configure ``queryset``, ``serializer_class``, permissions and
    filters. ``create``/``update`` stamp audit fields via ``req_operate_by_user``.
    """
    queryset = ''
    serializer_class = ''
    permission_classes = ()
    filter_fields = ()
    search_fields = ()
    filter_backends = (rest_framework.DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter,)

    def create(self, request, *args, **kwargs):
        data = req_operate_by_user(request)
        serializer = self.get_serializer(data=data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        return CCAIResponse(data="success")

    def list(self, request, *args, **kwargs):
        queryset = self.filter_queryset(self.get_queryset())
        page = self.paginate_queryset(queryset)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)
        serializer = self.get_serializer(queryset, many=True)
        return CCAIResponse(data=serializer.data)

    def retrieve(self, request, *args, **kwargs):
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return CCAIResponse(data=serializer.data)

    def update(self, request, *args, **kwargs):
        data = req_operate_by_user(request)
        # partial=True (PATCH) updates only the supplied fields; False (PUT) expects all fields.
        partial = kwargs.pop('partial', False)
        instance = self.get_object()
        serializer = self.get_serializer(instance, data=data, partial=partial)
        serializer.is_valid(raise_exception=True)
        self.perform_update(serializer)
        if getattr(instance, '_prefetched_objects_cache', None):
            # If 'prefetch_related' has been applied to a queryset, we need to
            # forcibly invalidate the prefetch cache on the instance.
            instance._prefetched_objects_cache = {}
        return CCAIResponse(data="success")

    def destroy(self, request, *args, **kwargs):
        instance = self.get_object()
        self.perform_destroy(instance)
        return CCAIResponse(data="delete resource success")

    @action(methods=['delete'], detail=False)
    def multiple_delete(self, request, *args, **kwargs):
        """Delete every object whose integer pk appears in the comma-separated ``ids`` query param."""
        try:
            delete_id = request.query_params.get('ids', None)
            if not delete_id:
                return CCAIResponse("参数不对啊!", NOT_FOUND)
            for i in delete_id.split(','):
                get_object_or_404(self.queryset, pk=int(i)).delete()
            return CCAIResponse("批量删除成功", OK)
        except Exception:
            error_logger.error("multiple delete crawler news failed: \n%s" % traceback.format_exc())
            return CCAIResponse("批量删除失败", SERVER_ERROR)


class CustomHashViewBase(ModelViewSet):
    """Like ``CustomViewBase`` but exposes primary keys encoded with ``ChaCeHashids``.

    ``list`` encodes the ``id`` field of each item; detail routes expect the
    encoded id as ``pk`` and answer ``"Invalid ID"`` when it does not decode.
    """
    queryset = ''
    serializer_class = ''
    permission_classes = ()
    filter_fields = ()
    search_fields = ()
    filter_backends = (rest_framework.DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter,)

    def _get_object_by_hashid(self, encoded_id):
        """Resolve a Hashids-encoded pk to a model instance.

        Returns None when *encoded_id* does not decode. Raises Http404 when no
        row matches (instead of an unhandled DoesNotExist) and runs the view's
        object-level permission checks, as ``GenericAPIView.get_object`` does.
        """
        decoded_id = ChaCeHashids.decode(encoded_id)
        if not decoded_id:
            return None
        instance = get_object_or_404(self.get_queryset(), pk=decoded_id[0])
        self.check_object_permissions(self.request, instance)
        return instance

    @staticmethod
    def _encode_ids(data):
        """Replace each item's integer ``id`` with its Hashids encoding, in place."""
        for item in data:
            item['id'] = ChaCeHashids.encode(item['id'])
        return data

    def create(self, request, *args, **kwargs):
        data = req_operate_by_user(request)
        serializer = self.get_serializer(data=data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        return CCAIResponse(data="success")

    def list(self, request, *args, **kwargs):
        queryset = self.filter_queryset(self.get_queryset())
        page = self.paginate_queryset(queryset)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(self._encode_ids(serializer.data))
        serializer = self.get_serializer(queryset, many=True)
        return CCAIResponse(data=self._encode_ids(serializer.data))

    def retrieve(self, request, *args, **kwargs):
        instance = self._get_object_by_hashid(kwargs.get('pk'))
        if instance is None:
            return CCAIResponse(data="Invalid ID", status=BAD)
        serializer = self.get_serializer(instance)
        return CCAIResponse(data=serializer.data)

    def update(self, request, *args, **kwargs):
        data = req_operate_by_user(request)
        partial = kwargs.pop('partial', False)
        instance = self._get_object_by_hashid(kwargs.get('pk'))
        if instance is None:
            return CCAIResponse(data="Invalid ID", status=BAD)
        serializer = self.get_serializer(instance, data=data, partial=partial)
        serializer.is_valid(raise_exception=True)
        self.perform_update(serializer)
        if getattr(instance, '_prefetched_objects_cache', None):
            instance._prefetched_objects_cache = {}
        return CCAIResponse(data="success")

    def destroy(self, request, *args, **kwargs):
        instance = self._get_object_by_hashid(kwargs.get('pk'))
        if instance is None:
            return CCAIResponse(data="Invalid ID", status=BAD)
        self.perform_destroy(instance)
        return CCAIResponse(data="delete resource success")

    @action(methods=['delete'], detail=False)
    def multiple_delete(self, request, *args, **kwargs):
        """Delete every object whose encoded id appears in the comma-separated ``ids`` query param."""
        try:
            delete_id = request.query_params.get('ids', None)
            if not delete_id:
                return CCAIResponse("参数不对啊!", status=NOT_FOUND)
            for encoded_id in delete_id.split(','):
                decoded_id = ChaCeHashids.decode(encoded_id)
                if not decoded_id:
                    return CCAIResponse("无效的ID: {}".format(encoded_id), status=NOT_FOUND)
                get_object_or_404(self.queryset, pk=decoded_id[0]).delete()
            return CCAIResponse("批量删除成功", status=OK)
        except Exception:
            error_logger.error("multiple delete failed: \n%s" % traceback.format_exc())
            return CCAIResponse("批量删除失败", status=SERVER_ERROR)


def chacerde_exception_handler(exc, context):
    """DRF exception handler that wraps error payloads as ``{code, message, detail[, wait]}``.

    ``wait`` (seconds until retry) is added for throttling errors. Exceptions
    DRF does not handle (response is None) are passed through unchanged.
    """
    response = exception_handler(exc, context)
    if response is None:
        return response
    wrapped = {
        "code": response.status_code,
        "message": "失败" if response.status_code >= 400 else "成功",
        "detail": response.data,
    }
    if isinstance(exc, Throttled):
        wrapped['wait'] = exc.wait
    response.data = wrapped
    return response


class CommonPagination(PageNumberPagination):
    """Page-number pagination that clamps out-of-range pages to the last page."""
    # default page size
    page_size = 10
    # query parameter that lets the client choose the page size
    page_size_query_param = settings.MY_PAGE_SIZE_QUERY_PARAM
    # query parameter carrying the page number
    page_query_param = settings.MY_PAGE_QUERY_PARAM
    # upper bound for a client-chosen page size
    max_page_size = settings.MY_MAX_PAGE_SIZE

    def get_paginated_response(self, data):
        """Return the page wrapped in ``CCAIResponse`` with the total ``count``."""
        return CCAIResponse(count=self.page.paginator.count, data=data)

    def paginate_queryset(self, queryset, request, view=None):
        """Paginate *queryset*; a page number past the end returns the last page.

        Returns None when pagination is disabled for this view. Non-numeric or
        otherwise invalid page numbers raise ``NotFound`` (HTTP 404).
        """
        page_size = self.get_page_size(request)
        if not page_size:
            return None

        paginator = self.django_paginator_class(queryset, page_size)
        page_number = request.query_params.get(self.page_query_param, 1)
        count = paginator.count  # total number of items
        if page_number in self.last_page_strings:
            page_number = paginator.num_pages

        # A non-numeric page must fall through to paginator.page(), which raises
        # InvalidPage -> NotFound, instead of crashing here with ValueError.
        try:
            requested_page = int(page_number)
        except (TypeError, ValueError):
            requested_page = None

        if count and requested_page and requested_page > paginator.num_pages:
            self.page = paginator.page(paginator.num_pages)  # clamp to the last page
        else:
            try:
                self.page = paginator.page(page_number)
            except InvalidPage as exc:
                msg = self.invalid_page_message.format(
                    page_number=page_number, message=str(exc)
                )
                raise NotFound(msg)

        if paginator.num_pages > 1 and self.template is not None:
            # The browsable API should display pagination controls.
            self.display_page_controls = True

        self.request = request
        return list(self.page)


class RbacPermission(BasePermission):
    """Role-based permission: grants access when a role permission matches the view's ``perms_map``."""

    @classmethod
    def get_permission_from_role(cls, request):
        """Return the permission method names granted by the user's roles.

        When the ``companyMid`` query parameter is present, only roles bound to
        that company or to no company are considered. Returns None when
        ``request.user`` is falsy or has no ``roles`` attribute (AttributeError).
        """
        try:
            if request.user:
                roles = request.user.roles
                company_mid = request.query_params.get('companyMid', None)
                if company_mid:
                    roles = roles.filter(Q(companyMid=company_mid) | Q(companyMid__isnull=True))
                return [item["permissions__method"]
                        for item in roles.values("permissions__method").distinct()]
        except AttributeError:
            return None
        return None

    def has_permission(self, request, view):
        perms = self.get_permission_from_role(request)
        if not perms:
            return False
        if "admin" in perms or not hasattr(view, "perms_map"):
            return True
        _method = request._request.method.lower()
        # perms_map is a list of {http_method_or_"*": permission_alias} dicts.
        for mapping in view.perms_map:
            for method, alias in mapping.items():
                if (_method == method or method == "*") and alias in perms:
                    return True
        return False


class ObjPermission(BasePermission):
    """Object-level permission for password records: admins or the owning user (``obj.uid_id``)."""

    def has_object_permission(self, request, view, obj):
        perms = RbacPermission.get_permission_from_role(request)
        if perms is None:
            # Roles could not be resolved (e.g. anonymous user): deny instead of
            # crashing on `"admin" in None`.
            return False
        if "admin" in perms:
            return True
        return request.user.id == obj.uid_id


class TreeSerializer(serializers.Serializer):
    id = serializers.IntegerField()
    label = serializers.CharField(max_length=20, source="name")
    pid = serializers.PrimaryKeyRelatedField(read_only=True)


class TreeAPIView(ListAPIView):
    """List view that nests items under their parent (``pid``) as ``children``."""
    serializer_class = TreeSerializer
    authentication_classes = (JSONWebTokenAuthentication,)
    permission_classes = (IsAuthenticated,)

    def list(self, request, *args, **kwargs):
        queryset = self.filter_queryset(self.get_queryset())
        # The page only decides the response shape; the tree is built from the
        # whole filtered queryset so parents are never cut off by pagination.
        page = self.paginate_queryset(queryset)
        serializer = self.get_serializer(queryset, many=True)
        tree_dict = {}
        tree_data = []
        try:
            for item in serializer.data:
                tree_dict[item["id"]] = item
            for i in tree_dict:
                if tree_dict[i]["pid"]:
                    pid = tree_dict[i]["pid"]
                    try:
                        parent = tree_dict[pid]
                    except KeyError:  # parent not in the result set
                        continue
                    parent.setdefault("children", []).append(tree_dict[i])
                else:
                    tree_data.append(tree_dict[i])
            results = tree_data
        except KeyError:
            results = serializer.data
        if page is not None:
            return self.get_paginated_response(results)
        return CCAIResponse(results)


def req_operate_by_user(request):
    """Return a copy of ``request.data`` stamped with the acting user's audit fields.

    POST sets CreateBy/CreateByUid and UpdateBy/UpdateByUid; PUT and PATCH set
    only UpdateBy/UpdateByUid.
    """
    data = request.data.copy()
    if request.method == "POST":
        data['CreateBy'] = request.user.name
        data['CreateByUid'] = request.user.id
    if request.method in ("POST", "PUT", "PATCH"):
        data['UpdateBy'] = request.user.name
        data['UpdateByUid'] = request.user.id
    return data


def sha1_encrypt(data):
    """Return the hex SHA-1 digest of the UTF-8 encoding of *data*."""
    return hashlib.sha1(data.encode('utf-8')).hexdigest()


def generate_random_str(randomlength=16):
    """Return *randomlength* random lowercase ASCII letters and digits.

    Uses the non-cryptographic ``random`` module; not suitable for secrets.
    """
    base_str = 'abcdefghijklmnopqrstuvwxyz0123456789'
    return ''.join(random.choice(base_str) for _ in range(randomlength))


def is_connection_usable():
    """Return True if the current DB connection answers a ping."""
    try:
        connection.connection.ping()
    except Exception:
        return False
    return True


def generate_random_str_for_fileName(randomlength=10):
    """Return *randomlength* random ASCII letters/digits, safe for use in file names."""
    base_str = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
    return ''.join(random.choice(base_str) for _ in range(randomlength))


def generate_random_str_16_system(randomlength=16):
    """Return *randomlength* random lowercase hex characters.

    The result is used as AES key/IV material (see ``start_encrypt``), so the
    characters come from the cryptographically secure ``secrets`` module.
    """
    return ''.join(secrets.choice('0123456789abcdef') for _ in range(randomlength))


def generate_random_pwd(the_length=9):
    """Return a random plaintext password of ``max(the_length, 9)`` characters.

    The password contains 1-2 characters from ``!@#$%^&*+``, 3-4 digits and
    ASCII letters for the rest, in shuffled order. Uses ``secrets.SystemRandom``.
    """
    rng = secrets.SystemRandom()
    if the_length < 9:
        the_length = 9
    special_str = "!@#$%^&*+"
    nums = "1234567890"
    zi_mu = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    special_str_length = rng.randint(1, 2)  # number of special characters
    nums_length = rng.randint(3, 4)  # number of digits
    letters_length = the_length - special_str_length - nums_length
    chars = ([rng.choice(special_str) for _ in range(special_str_length)]
             + [rng.choice(nums) for _ in range(nums_length)]
             + [rng.choice(zi_mu) for _ in range(letters_length)])
    rng.shuffle(chars)
    return "".join(chars)


# X-Forwarded-For (XFF) carries the original client IP; it is only added when the
# request passes through an HTTP proxy or load balancer.
def get_ip(request):
    """Return the requester's IP address, or '' on failure.

    Prefers ``X-Real-IP``; otherwise follows DRF's NUM_PROXIES logic over
    ``X-Forwarded-For`` and ``REMOTE_ADDR``.
    """
    try:
        remote_addr = request.META.get('HTTP_X_REAL_IP')
        if remote_addr:
            return remote_addr
        xff = request.META.get('HTTP_X_FORWARDED_FOR')
        remote_addr = request.META.get('REMOTE_ADDR')
        num_proxies = api_settings.NUM_PROXIES
        if num_proxies is not None:
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()
        return ''.join(xff.split()) if xff else remote_addr
    except Exception:
        return ''


def create_operation_history_log(request, info, TheModel):
    """Persist an operation-history row built from ``info["des"]``/``info["detail"]``.

    With a request, the client IP and (if authenticated) the user are recorded;
    without one the row is attributed to "task". Failures are logged, not raised.
    """
    try:
        if request:
            ip = get_ip(request)
            if request.user.id:
                TheModel.objects.create(ip=ip, des=info["des"], user_id=request.user.id,
                                        username=request.user.username, detail=info["detail"])
            else:
                TheModel.objects.create(ip=ip, des=info["des"], username="智云用户", detail=info["detail"])
        else:
            TheModel.objects.create(des=info["des"], username="task", detail=info["detail"])
    except Exception:
        # request may be None here, so do not dereference request.user directly.
        user_id = getattr(getattr(request, 'user', None), 'id', None)
        error_logger.error("user: %s, create history failed \n%s" % (user_id, traceback.format_exc()))


def check_password_complexity(password):
    """Validate *password* and return ``{"flag": bool, "message": str}``.

    Rules: 8-25 characters; only ASCII letters, digits and the specials
    ``.!`/?@#$%^&*()+-``; at least 2 specials, 1 digit and 1 letter.
    """
    special_chars = set(".!`/?@#$%^&*()+-")  # at least two required
    digit_chars = set("1234567890")  # at least one required
    letter_chars = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")  # at least one required
    if len(password) < 8:
        return {"flag": False, "message": "密码长度至少8个字符及以上"}
    if len(password) > 25:
        return {"flag": False, "message": "密码长度至多25个字符及以下"}
    special_count = digit_count = letter_count = 0
    for each in password:
        if each in special_chars:
            special_count += 1
        elif each in digit_chars:
            digit_count += 1
        elif each in letter_chars:
            letter_count += 1
        else:
            return {"flag": False, "message": "密码不允许存在其他不规范字符"}
    if special_count < 2:
        return {"flag": False, "message": "密码必须包含2个英文特殊字符及以上如:.!`/?@#$%^&*()+-"}
    if digit_count < 1:
        return {"flag": False, "message": "密码必须包含数字"}
    if letter_count < 1:
        return {"flag": False, "message": "密码必须包含字母"}
    return {"flag": True, "message": "success"}


def record_punishment(data, request, PunishmentInfoModel):
    """Create or update the warning/punishment record for this user and ``new_paper_id``.

    A new record gets a session suffix ("-第N次测试") based on how many records
    the user already has for ``new_exam_id``. Note: mutates ``data["exam_name"]``.
    """
    record = PunishmentInfoModel.objects.filter(user_id=request.user.id, new_paper_id=data["new_paper_id"]).first()
    if record:
        record.warning_count = data["warning_count"]
        record.is_punishment = data["is_punishment"]
        record.save()
        return
    batch = request.user.batch.name if request.user.batch else None
    session = PunishmentInfoModel.objects.filter(user_id=request.user.id, new_exam_id=data["new_exam_id"]).count()
    data["exam_name"] = data["exam_name"] + "-第" + str(session + 1) + "次测试"
    PunishmentInfoModel.objects.create(user_id=request.user.id, username=request.user.username,
                                       exam_id=data["exam_id"], new_exam_id=data["new_exam_id"],
                                       exam_name=data["exam_name"], paper_id=data["paper_id"],
                                       new_paper_id=data["new_paper_id"], warning_count=data["warning_count"],
                                       max_warning_count=data["max_warning_count"], batch=batch,
                                       is_punishment=data["is_punishment"], name=request.user.name)


class AESEnDECryptRelated:
    """AES-CBC helper whose output string embeds the random key and IV.

    Layout of ``start_encrypt`` output: iv[:8] + key[8:] + hex ciphertext +
    key[:8] + iv[8:] + "==". Because the key travels with the ciphertext this
    only obfuscates data; it gives no confidentiality against anyone who knows
    the layout.
    """

    def __init__(self):
        self.model = AES.MODE_CBC  # cipher mode

    def add_to_16(self, text):
        """PKCS#7-pad *text* to a multiple of 16 UTF-8 bytes and return the bytes.

        pycryptodome does no padding itself; the pad length is computed from the
        UTF-8 byte length so multi-byte (e.g. Chinese) text is padded correctly.
        """
        pad = 16 - len(text.encode('utf-8')) % 16
        text = text + pad * chr(pad)
        return text.encode('utf-8')

    def encrypt(self, text, key, iv):
        """Encrypt *text* with AES-CBC(*key*, *iv*) and return lowercase hex."""
        text = self.add_to_16(text)
        cryptos = AES.new(key, self.model, iv)
        cipher_text = cryptos.encrypt(text)
        return b2a_hex(cipher_text).decode('utf-8')

    def decrypt(self, text):
        """Decrypt a string produced by ``start_encrypt`` and return it as str.

        The PKCS#7 padding is NOT removed; callers strip it (e.g.
        ``s[:-ord(s[-1])]``) before ``json.loads``.
        """
        ciphertext, key_str, iv_str = self.get_decrypt_info(text)
        text = a2b_hex(ciphertext)
        iv = iv_str.encode('utf-8')
        key = key_str.encode('utf-8')
        cryptos = AES.new(key, self.model, iv)
        plain_text = cryptos.decrypt(text)
        return bytes.decode(plain_text)

    def get_decrypt_info(self, text):
        """Split a combined string into ``(hex_ciphertext, key_str, iv_str)``."""
        text = text[:len(text) - 2]  # drop the trailing "=="
        iv = text[:8] + text[-8:]
        key = text[-16:-8] + text[8:16]
        ciphertext = text[16:-16]
        return ciphertext, key, iv

    def combination_new_ciphertext(self, ciphertext, iv_str, key_str):
        """Return iv[:8] + key[8:] + ciphertext + key[:8] + iv[8:] + "=="."""
        new_ciphertext = iv_str[:8] + key_str[8:] + ciphertext + key_str[:8] + iv_str[8:] + "=="
        return new_ciphertext

    def start_encrypt(self, text):
        """Encrypt *text* (JSON-dumped first if not a str) with a fresh random key/IV."""
        key_str = generate_random_str_16_system(16)
        key = key_str.encode('utf-8')
        iv_str = generate_random_str_16_system(16)
        iv = iv_str.encode('utf-8')
        if not isinstance(text, str):
            e1 = self.encrypt(json.dumps(text), key, iv)
        else:
            e1 = self.encrypt(text, key, iv)
        return self.combination_new_ciphertext(e1, iv_str, key_str)


def is_valid_debit_card(card_number):
    """Return True if *card_number* consists of 16-19 ASCII digits.

    Only the format is checked; no Luhn checksum is applied.
    """
    # isascii() excludes full-width and other Unicode digits that isdigit() accepts.
    return card_number.isascii() and card_number.isdigit() and 16 <= len(card_number) <= 19


def is_all_chinese(s):
    """Return True if *s* is non-empty and consists only of CJK characters U+4E00-U+9FA5."""
    # fullmatch (unlike match with `$`) does not accept a trailing newline.
    return re.fullmatch(r'[\u4e00-\u9fa5]+', s) is not None


def is_valid_username(username):
    """Return True if *username* is non-empty and only CJK characters, ASCII letters or digits."""
    pattern = re.compile(r'[\u4e00-\u9fa5A-Za-z0-9]+')
    return bool(pattern.fullmatch(username))


def calculate_sub_ratio(main_ratio, sub_ratio):
    """Return ``main_ratio * sub_ratio / 100`` as a Decimal rounded half-up to 2 places."""
    main_ratio = Decimal(main_ratio)
    sub_ratio = Decimal(sub_ratio)
    result = main_ratio * sub_ratio / Decimal(100)
    return result.quantize(Decimal('0.00'), rounding=ROUND_HALF_UP)
def asyncDeleteFile(request, fileName):
    """Delete *fileName*, retrying on PermissionError, then remove its empty parent dir.

    Runs synchronously. A PermissionError (e.g. the file is still open) is
    retried up to 10 times, one second apart. Any other error is logged and
    swallowed.

    :param request: originating request, used only to log the user id; may be None.
    :param fileName: path of the file to delete.
    """
    max_attempts = 10  # maximum number of delete attempts
    try:
        for attempt in range(1, max_attempts + 1):
            try:
                os.remove(fileName)
            except PermissionError as e:
                # The file may still be in use: wait and try again.
                logging.getLogger('error').error(f"文件删除失败:{e}")
                if attempt >= max_attempts:
                    logging.getLogger('error').error("已达到最大尝试次数,无法删除文件")
                    break
                time.sleep(1)
                continue
            logging.getLogger('info').info('delete success')
            parent_dir = os.path.dirname(fileName)
            if parent_dir and not os.listdir(parent_dir):
                os.rmdir(parent_dir)
            break
    except Exception:
        # request may be None (e.g. background callers), so don't dereference it blindly.
        user_id = getattr(getattr(request, 'user', None), 'id', None)
        logging.getLogger('error').error(
            "user: %s, delete FileName: %s file failed: \n%s" % (
                user_id, fileName, traceback.format_exc()))


class MyCustomError(Exception):
    """Custom application error carrying a human-readable ``message``."""

    def __init__(self, message):
        self.message = message
        super().__init__(self.message)


_CN_NUMERALS = '零壹贰叁肆伍陆柒捌玖'
_CN_SMALL_UNITS = ('', '拾', '佰', '仟')  # units inside a 4-digit group
_CN_BIG_UNITS = ('', '万', '亿', '兆')  # units between 4-digit groups


def _four_digit_group_to_chinese(group):
    """Spell 1..9999 in capital numerals, writing 零 for inner zero runs."""
    text = ''
    zero_pending = False
    for position in range(3, -1, -1):
        digit = (group // 10 ** position) % 10
        if digit == 0:
            zero_pending = bool(text)  # only zeros after a non-zero digit matter
            continue
        if zero_pending:
            text += '零'
            zero_pending = False
        text += _CN_NUMERALS[digit] + _CN_SMALL_UNITS[position]
    return text


def _integer_to_chinese(number):
    """Spell a non-negative integer (< 10**16) in capital numerals; 0 gives ''."""
    groups = []
    while number:
        number, group = divmod(number, 10000)
        groups.append(group)
    text = ''
    zero_pending = False
    for index in range(len(groups) - 1, -1, -1):  # most significant group first
        group = groups[index]
        if group == 0:
            zero_pending = bool(text)
            continue
        # A skipped zero group or a group with leading zeros needs a 零 separator.
        if text and (zero_pending or group < 1000):
            text += '零'
        text += _four_digit_group_to_chinese(group) + _CN_BIG_UNITS[index]
        zero_pending = False
    return text


def digit_to_chinese(num):
    """Convert a money amount to capital Chinese numerals (大写金额).

    The amount is rounded half-up to 2 decimals. Examples:
    1234.56 -> 壹仟贰佰叁拾肆元伍角陆分, 1001 -> 壹仟零壹元整,
    1.05 -> 壹元零伍分, 0.5 -> 伍角, 0 -> 零元整.

    :param num: non-negative int, float, Decimal or numeric string.
    :return: the Chinese string, or 0 (after logging) when *num* is negative,
        non-numeric, or 10**16 and above.
    """
    try:
        amount = Decimal(str(num)).quantize(Decimal('0.01'), rounding=ROUND_HALF_UP)
        if amount < 0:
            raise ValueError("negative amounts are not supported")
        integer_part = int(amount)
        cents = int((amount - integer_part) * 100)
        jiao, fen = divmod(cents, 10)

        result = _integer_to_chinese(integer_part)
        if result:
            result += '元'
        if cents == 0:
            return (result or '零元') + '整'
        if jiao:
            result += _CN_NUMERALS[jiao] + '角'
        if fen:
            if integer_part and not jiao:
                result += '零'  # e.g. 壹元零伍分
            result += _CN_NUMERALS[fen] + '分'
        return result
    except Exception:
        logging.getLogger('error').error(
            "digit: %s changeTo chinese failed: \n%s" % (num, traceback.format_exc()))
        return 0


def get_first_and_last_day(year_month):
    """Return ``(first_day, last_day)`` dates of the month in a 'YYYY-MM[-...]' string."""
    year, month = map(int, year_month.split('-')[:2])
    days_in_month = calendar.monthrange(year, month)[1]
    return date(year, month, 1), date(year, month, days_in_month)


def get_all_months(start_date, end_date):
    """Return every 'YYYY-MM' from *start_date*'s month through *end_date*'s month.

    Each argument may be a ``date``, a ``datetime`` or a 'YYYY-MM-DD' string;
    mixed kinds are allowed. Returns [] when start is after end.
    """
    def _as_date(value):
        if isinstance(value, datetime):  # check first: datetime subclasses date
            return value.date()
        if isinstance(value, date):
            return value
        return datetime.strptime(value, '%Y-%m-%d').date()

    start, end = _as_date(start_date), _as_date(end_date)
    months = []
    year, month = start.year, start.month
    while (year, month) <= (end.year, end.month):
        months.append(f"{year:04d}-{month:02d}")
        year, month = (year + 1, 1) if month == 12 else (year, month + 1)
    return months


def calculate_hours(AmStart, AmEnd, PmStart, PmEnd, Separate=False):
    """Return the company's working hours from 'HH:MM' morning/afternoon bounds.

    Each period is counted in half-hour steps, rounding any partial half hour
    up. Returns the total hours, or ``(am_hours, pm_hours)`` when *Separate*
    is True. Uses ``timedelta.seconds``, so an end before its start wraps
    around midnight.
    """
    time_format = '%H:%M'

    def _half_hours(start, end):
        span = datetime.strptime(end, time_format) - datetime.strptime(start, time_format)
        slots, remainder = divmod(span.seconds, 1800)
        return slots + 1 if remainder else slots

    am_half_hours = _half_hours(AmStart, AmEnd)
    pm_half_hours = _half_hours(PmStart, PmEnd)
    if Separate:
        return am_half_hours / 2, pm_half_hours / 2
    return (am_half_hours + pm_half_hours) / 2


def get_month_days(year_month):
    """Return the number of days in the month given as 'YYYY-MM'."""
    year, month = map(int, year_month.split('-'))
    return calendar.monthrange(year, month)[1]


class ErrorTable:
    """Table of rejected input rows, each with an ``ErrorMessage`` column appended."""

    def __init__(self, column_names: List[str]):
        if not column_names:
            raise ValueError('初始化列名必填')
        self.column_names = list(column_names) + ['ErrorMessage']
        self.data = []

    def has_data(self) -> bool:
        return bool(self.data)

    def add_one_row(self, row_data: List[Any], error_message: str = None):
        """Append a row; *row_data* needs one value per user column and is not modified.

        Raises MyCustomError when the number of values does not match.
        """
        if len(row_data) != len(self.column_names) - 1:
            raise MyCustomError("添加错误信息行失败,表格列数与输入列数不符")
        values = list(row_data) + ["" if error_message is None else error_message]
        self.data.append(dict(zip(self.column_names, values)))

    def get_table(self) -> dict:
        """Return ``{'name': column names, 'data': list of row dicts}``."""
        return {'name': list(self.column_names), 'data': self.data}


def common_update_sql(item, table_name):
    """Build the 'update <table> set k = %s,...' prefix (before WHERE) for dict *item*.

    Values are left as %s placeholders for parameter binding. SECURITY: the
    table name and the keys of *item* are interpolated verbatim, so they must
    never come from untrusted input.
    """
    assignments = ','.join(" {key} = %s".format(key=key) for key in item)
    return 'update {table} set '.format(table=table_name) + assignments


def remove_invisible_characters(value):
    """Remove every whitespace character (``\\s``: spaces, tabs, newlines...) from *value*."""
    return re.sub(r'\s+', '', value)


def transform_characters(value):
    """Convert full-width characters to half-width ASCII via NFKD.

    Note: characters with no ASCII decomposition (e.g. Chinese) are dropped.
    """
    return unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')


def count_1(n: int) -> int:
    """Return the number of clock-in records, i.e. set bits in the binary form of *n*.

    @param n: non-negative integer encoding the clock-in bitmap
    @raise ValueError: if *n* is negative (the old shift loop never terminated)
    """
    if n < 0:
        raise ValueError("n must be non-negative")
    return bin(n).count("1")


def get_month_last(_date: date, length: int) -> list[str]:
    """Return the *length* months before *_date*'s month (excluding it), newest first.

    @param _date: reference date
    @param length: number of months wanted
    @return: list of 'YYYY-MM' strings
    """
    year, month = _date.year, _date.month
    months = []
    for _ in range(length):
        year, month = (year - 1, 12) if month == 1 else (year, month - 1)
        months.append(f"{year}-{month:02d}")
    return months


def round_to_highest_digit(number):
    """Truncate *number* to its most significant digit (3456 -> 3000, -350 -> -300).

    Note: for |number| < 1 the exponent is truncated toward zero, so e.g.
    0.5 yields 0.
    """
    if number == 0:
        return 0
    digits = int(math.log10(abs(number)))  # exponent of the leading digit
    highest_digit_value = int(number / (10 ** digits))
    return highest_digit_value * (10 ** digits)


def mul_split(text):
    """Split *text* on any of several ASCII/full-width separators.

    Empty pieces are dropped and the rest stripped. Returns None (after
    logging) if splitting fails, e.g. when *text* is not a str.
    """
    try:
        split_string = ' ,.; ,。;'  # separator characters
        result = re.split(r'[' + re.escape(split_string) + ']', text)
        return [item.strip() for item in result if item]
    except Exception:
        error_logger.error("text: %s, mul_split failed: %s" % (text, traceback.format_exc()))
        return None


def time_str_to_int(time_str):
    """Convert 'MM:SS' to a number of seconds."""
    minutes, seconds = map(int, time_str.split(':'))
    return minutes * 60 + seconds


def time_int_to_str(total_seconds):
    """Convert a number of seconds to 'MM:SS'."""
    if 0 == total_seconds:
        return '00:00'
    minutes, seconds = divmod(total_seconds, 60)
    return f"{minutes:02d}:{seconds:02d}"