core.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. # -*- coding: utf-8 -*-
  2. import json
  3. import os
  4. from datetime import datetime
  5. from django.conf import settings
  6. import requests
  7. from .type import RequestType, SignType
  8. from .utils import aes_decrypt, build_authorization, hmac_sign, load_certificate, load_private_key, rsa_decrypt, rsa_encrypt, rsa_sign, rsa_verify
  9. class Core(object):
  10. def __init__(self, mchid, cert_serial_no, apiv3_key, logger=None, proxy=None):
  11. self._mchid = mchid # 商户号
  12. self._cert_serial_no = cert_serial_no # 商户证书序列号
  13. self._private_key = self._innt_private_key() # 商户证书私钥
  14. self._apiv3_key = apiv3_key # 商户APIv3密钥
  15. self._gate_way = 'https://api.mch.weixin.qq.com'
  16. self._certificates = []
  17. self._cert_dir = settings.PUBLIC_CERT_ROOT # 平台证书存放目录(减少证书下载调用次数)
  18. self._logger = logger # 日志记录器
  19. self._proxy = proxy # 代理设置
  20. self._init_certificates()
  21. def _innt_private_key(self):
  22. '''
  23. 加载商户私钥 PRIVATE_CERT_ROOT是在settings里边设置商户私钥证书地址
  24. :return:
  25. '''
  26. os_path = settings.PRIVATE_CERT_ROOT + "apiclient_key.pem"
  27. with open(os_path) as f:
  28. private_key = f.read()
  29. return load_private_key(private_key)
  30. def _init_certificates(self):
  31. '''
  32. 初始化平台证书 如果存在平台证书就加载平台证书 如果没有平台证书或者平台证书过期就下载平台证书
  33. :return:
  34. '''
  35. if self._cert_dir and os.path.exists(self._cert_dir):
  36. for file_name in os.listdir(self._cert_dir):
  37. if not file_name.lower().endswith('.pem'):
  38. continue
  39. with open(self._cert_dir + file_name, encoding="utf-8") as f:
  40. certificate = load_certificate(f.read())
  41. now = datetime.utcnow()
  42. if certificate and now >= certificate.not_valid_before and now <= certificate.not_valid_after:
  43. self._certificates.append(certificate)
  44. if not self._certificates:
  45. self._update_certificates()
  46. if not self._certificates:
  47. raise Exception('未发现平台证书,请仔细检查您的初始化参数!')
  48. def _update_certificates(self):
  49. '''
  50. 下载平台证书
  51. :return:
  52. '''
  53. path = '/v3/certificates'
  54. self._certificates.clear()
  55. code, message = self.request(path, skip_verify=True)
  56. if code != 200:
  57. return
  58. data = json.loads(message).get('data')
  59. for value in data:
  60. serial_no = value.get('serial_no')
  61. effective_time = value.get('effective_time')
  62. expire_time = value.get('expire_time')
  63. encrypt_certificate = value.get('encrypt_certificate')
  64. algorithm = nonce = associated_data = ciphertext = None
  65. if encrypt_certificate:
  66. algorithm = encrypt_certificate.get('algorithm')
  67. nonce = encrypt_certificate.get('nonce')
  68. associated_data = encrypt_certificate.get('associated_data')
  69. ciphertext = encrypt_certificate.get('ciphertext')
  70. if not (serial_no and effective_time and expire_time and algorithm and nonce and associated_data and ciphertext):
  71. continue
  72. cert_str = aes_decrypt(nonce=nonce, ciphertext=ciphertext, associated_data=associated_data, apiv3_key=self._apiv3_key)
  73. certificate = load_certificate(cert_str)
  74. if not certificate:
  75. continue
  76. now = datetime.utcnow()
  77. if now < certificate.not_valid_before or now > certificate.not_valid_after:
  78. continue
  79. self._certificates.append(certificate)
  80. if not self._cert_dir:
  81. continue
  82. if not os.path.exists(self._cert_dir):
  83. os.makedirs(self._cert_dir)
  84. if not os.path.exists(self._cert_dir + serial_no + '.pem'):
  85. with open(self._cert_dir + serial_no + '.pem', 'w') as f:
  86. f.write(cert_str)
  87. def _verify_signature(self, headers, body):
  88. signature = headers.get('Wechatpay-Signature')
  89. timestamp = headers.get('Wechatpay-Timestamp')
  90. nonce = headers.get('Wechatpay-Nonce')
  91. serial_no = headers.get('Wechatpay-Serial')
  92. cert_found = False
  93. for cert in self._certificates:
  94. if int('0x' + serial_no, 16) == cert.serial_number:
  95. cert_found = True
  96. certificate = cert
  97. break
  98. if not cert_found:
  99. self._update_certificates()
  100. for cert in self._certificates:
  101. if int('0x' + serial_no, 16) == cert.serial_number:
  102. cert_found = True
  103. certificate = cert
  104. break
  105. if not cert_found:
  106. return False
  107. if not rsa_verify(timestamp, nonce, body, signature, certificate):
  108. return False
  109. return True
  110. def request(self, path, method=RequestType.GET, data=None, skip_verify=False, sign_data=None, files=None, cipher_data=False, headers={}):
  111. headers.update({'Content-Type': 'application/json'})
  112. if files:
  113. headers['Content-Type'] = 'multipart/form-data'
  114. headers.update({'Accept': 'application/json'})
  115. if cipher_data:
  116. headers.update({'Wechatpay-Serial': hex(self._last_certificate().serial_number)[2:].upper()})
  117. authorization = build_authorization(path, method.value, self._mchid, self._cert_serial_no, self._private_key, data=sign_data if sign_data else data)
  118. headers.update({'Authorization': authorization})
  119. if method == RequestType.GET:
  120. response = requests.get(url=self._gate_way + path, headers=headers, proxies=self._proxy)
  121. elif method == RequestType.POST:
  122. response = requests.post(url=self._gate_way + path, json=None if files else data, data=data if files else None, headers=headers, files=files, proxies=self._proxy)
  123. elif method == RequestType.PATCH:
  124. response = requests.patch(url=self._gate_way + path, json=data, headers=headers, proxies=self._proxy)
  125. elif method == RequestType.PUT:
  126. response = requests.put(url=self._gate_way + path, json=data, headers=headers, proxies=self._proxy)
  127. elif method == RequestType.DELETE:
  128. response = requests.delete(url=self._gate_way + path, headers=headers, proxies=self._proxy)
  129. else:
  130. raise Exception('请求类型不被支持!')
  131. if response.status_code in range(200, 300) and not skip_verify:
  132. if not self._verify_signature(response.headers, response.text):
  133. raise Exception('验证签名失败!')
  134. return response.status_code, response.text if 'application/json' in response.headers.get('Content-Type') else response.content
  135. def sign(self, data, sign_type=SignType.RSA_SHA256):
  136. if sign_type == SignType.RSA_SHA256:
  137. sign_str = '\n'.join(data) + '\n'
  138. return rsa_sign(self._private_key, sign_str)
  139. elif sign_type == SignType.HMAC_SHA256:
  140. key_list = sorted(data.keys())
  141. sign_str = ''
  142. for k in key_list:
  143. v = data[k]
  144. sign_str += str(k) + '=' + str(v) + '&'
  145. sign_str += 'key=' + self._apiv3_key
  146. return hmac_sign(self._apiv3_key, sign_str)
  147. else:
  148. raise Exception('错误的签名类型!')
  149. def decrypt_callback(self, headers, body):
  150. if isinstance(body, bytes):
  151. body = body.decode('UTF-8')
  152. if not self._verify_signature(headers, body):
  153. return None
  154. data = json.loads(body)
  155. resource_type = data.get('resource_type')
  156. if resource_type != 'encrypt-resource':
  157. return None
  158. resource = data.get('resource')
  159. if not resource:
  160. return None
  161. algorithm = resource.get('algorithm')
  162. if algorithm != 'AEAD_AES_256_GCM':
  163. raise Exception('该算法不被支持!')
  164. nonce = resource.get('nonce')
  165. ciphertext = resource.get('ciphertext')
  166. associated_data = resource.get('associated_data')
  167. if not (nonce and ciphertext):
  168. return None
  169. if not associated_data:
  170. associated_data = ''
  171. result = aes_decrypt(
  172. nonce=nonce,
  173. ciphertext=ciphertext,
  174. associated_data=associated_data,
  175. apiv3_key=self._apiv3_key)
  176. return result
  177. def callback(self, headers, body):
  178. if isinstance(body, bytes):
  179. body = body.decode('UTF-8')
  180. result = self.decrypt_callback(headers=headers, body=body)
  181. if result:
  182. data = json.loads(body)
  183. data.update({'resource': json.loads(result)})
  184. return data
  185. else:
  186. return result
  187. def decrypt(self, ciphtext):
  188. return rsa_decrypt(ciphertext=ciphtext, private_key=self._private_key)
  189. def encrypt(self, text):
  190. return rsa_encrypt(text=text, certificate=self._last_certificate())
  191. def _last_certificate(self):
  192. if not self._certificates:
  193. self._update_certificates()
  194. certificate = self._certificates[0]
  195. for cert in self._certificates:
  196. if certificate.not_valid_after < cert.not_valid_after:
  197. certificate = cert
  198. return certificate