service/vendor/github.com/go-pay/gopay/wechat/param.go

293 lines
8.3 KiB
Go
Raw Normal View History

2024-02-20 23:05:24 +08:00
package wechat
import (
"context"
"crypto/hmac"
"crypto/md5"
"crypto/sha256"
"crypto/tls"
"encoding/hex"
"encoding/pem"
"encoding/xml"
"errors"
"fmt"
"hash"
"io/ioutil"
"strings"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/pkg/util"
"github.com/go-pay/gopay/pkg/xhttp"
"github.com/go-pay/gopay/pkg/xlog"
"golang.org/x/crypto/pkcs12"
)
// Country identifies the region whose WeChat Pay API endpoint a Client talks to.
// SetCountry maps each value (China, China2, SoutheastAsia, Other) to a base URL.
type Country int
// SetCountry selects the WeChat Pay API base URL for the merchant's region
// (default: mainland China) and returns w so calls can be chained.
//
// country:
//   - China:         mainland China (baseUrlCh)
//   - China2:        mainland China disaster-recovery endpoint (baseUrlCh2)
//   - SoutheastAsia: Southeast Asia (baseUrlHk)
//   - Other:         other countries (baseUrlUs)
//
// Any unrecognized value falls back to baseUrlCh.
func (w *Client) SetCountry(country Country) (client *Client) {
	w.mu.Lock()
	defer w.mu.Unlock()
	switch country {
	case China2:
		w.BaseURL = baseUrlCh2
	case SoutheastAsia:
		w.BaseURL = baseUrlHk
	case Other:
		w.BaseURL = baseUrlUs
	default: // China, and any value not listed above
		w.BaseURL = baseUrlCh
	}
	return w
}
// AddCertPemFilePath loads the WeChat Pay client certificate from PEM files on disk.
//
// certFilePath: path to apiclient_cert.pem
// keyFilePath:  path to apiclient_key.pem
func (w *Client) AddCertPemFilePath(certFilePath, keyFilePath string) (err error) {
	err = w.addCertFileContentOrPath(certFilePath, keyFilePath, nil)
	return err
}
// AddCertPkcs12FilePath loads the WeChat Pay client certificate from a PKCS#12 file on disk.
// The bundle is decrypted with the client's MchId as the password.
//
// pkcs12FilePath: path to apiclient_cert.p12
func (w *Client) AddCertPkcs12FilePath(pkcs12FilePath string) (err error) {
	err = w.addCertFileContentOrPath(nil, nil, pkcs12FilePath)
	return err
}
// AddCertPemFileContent loads the WeChat Pay client certificate from in-memory PEM data.
//
// certFileContent: contents of apiclient_cert.pem
// keyFileContent:  contents of apiclient_key.pem
func (w *Client) AddCertPemFileContent(certFileContent, keyFileContent []byte) (err error) {
	err = w.addCertFileContentOrPath(certFileContent, keyFileContent, nil)
	return err
}
// AddCertPkcs12FileContent loads the WeChat Pay client certificate from in-memory
// PKCS#12 data, decrypted with the client's MchId as the password.
//
// p12FileContent: contents of apiclient_cert.p12
func (w *Client) AddCertPkcs12FileContent(p12FileContent []byte) (err error) {
	err = w.addCertFileContentOrPath(nil, nil, p12FileContent)
	return err
}
// addCertFileContentOrPath loads a client certificate from file paths or raw
// content and stores it in w.Certificate (written under w.mu).
//
// Supplying only the PEM pair (certFile + keyFile) or only the PKCS#12 bundle
// (pkcs12File) is enough; the three arguments need not all be set. Each
// non-nil argument is either a string file path or []byte content.
func (w *Client) addCertFileContentOrPath(certFile, keyFile, pkcs12File interface{}) (err error) {
	if err = checkCertFilePathOrContent(certFile, keyFile, pkcs12File); err != nil {
		return err
	}
	var cfg *tls.Config
	cfg, err = w.addCertConfig(certFile, keyFile, pkcs12File)
	if err != nil {
		return err
	}
	// On success addCertConfig builds its config with a single certificate.
	cert := &cfg.Certificates[0]
	w.mu.Lock()
	w.Certificate = cert
	w.mu.Unlock()
	return nil
}
// addCertConfig builds a *tls.Config that presents a client certificate.
//
// Sources are resolved in this order:
//   - certFile, keyFile and pkcs12File all nil: reuse the certificate already
//     stored in w.Certificate (read under w.mu), or fail if none is set.
//   - certFile and keyFile both non-nil: load the PEM certificate and key.
//   - otherwise, pkcs12File non-nil: decode the PKCS#12 bundle with w.MchId as
//     its password and use the resulting PEM blocks as both cert and key input.
//
// Each source may be []byte content or a string file path; any other type is
// reported as an error instead of panicking on a type assertion.
//
// NOTE(review): the returned config sets InsecureSkipVerify, which disables
// verification of the server's certificate chain and host name for
// connections using it. Kept unchanged for compatibility — confirm intent.
func (w *Client) addCertConfig(certFile, keyFile, pkcs12File interface{}) (tlsConfig *tls.Config, err error) {
	if certFile == nil && keyFile == nil && pkcs12File == nil {
		w.mu.RLock()
		defer w.mu.RUnlock()
		if w.Certificate != nil {
			tlsConfig = &tls.Config{
				Certificates:       []tls.Certificate{*w.Certificate},
				InsecureSkipVerify: true,
			}
			return tlsConfig, nil
		}
		return nil, errors.New("cert parse failed or nil")
	}

	// readSource returns v unchanged when it is []byte content, or the contents
	// of the file it names when it is a string path.
	readSource := func(name string, v interface{}) ([]byte, error) {
		switch src := v.(type) {
		case []byte:
			return src, nil
		case string:
			bs, rerr := ioutil.ReadFile(src)
			if rerr != nil {
				return nil, fmt.Errorf("ioutil.ReadFile: %w", rerr)
			}
			return bs, nil
		default:
			return nil, fmt.Errorf("%s type error", name)
		}
	}

	var certPem, keyPem []byte
	switch {
	case certFile != nil && keyFile != nil:
		// Stop at the first failure: reading the key must not overwrite an
		// error from reading the certificate.
		if certPem, err = readSource("certFile", certFile); err != nil {
			return nil, err
		}
		if keyPem, err = readSource("keyFile", keyFile); err != nil {
			return nil, err
		}
	case pkcs12File != nil:
		var pfxData []byte
		if pfxData, err = readSource("pkcs12File", pkcs12File); err != nil {
			return nil, err
		}
		var blocks []*pem.Block
		if blocks, err = pkcs12.ToPEM(pfxData, w.MchId); err != nil {
			return nil, fmt.Errorf("pkcs12.ToPEM: %w", err)
		}
		for _, b := range blocks {
			keyPem = append(keyPem, pem.EncodeToMemory(b)...)
		}
		// The bundle contains both key and certificate blocks; tls.X509KeyPair
		// picks the certificate blocks from the first argument and the key
		// block from the second, so the same PEM data serves as both.
		certPem = keyPem
	}

	if certPem == nil || keyPem == nil {
		return nil, errors.New("cert files must all nil or all not nil")
	}
	certificate, err := tls.X509KeyPair(certPem, keyPem)
	if err != nil {
		return nil, fmt.Errorf("tls.X509KeyPair: %w", err)
	}
	tlsConfig = &tls.Config{
		Certificates:       []tls.Certificate{certificate},
		InsecureSkipVerify: true,
	}
	return tlsConfig, nil
}
// checkCertFilePathOrContent validates the certificate sources before any of
// them is read.
//
// Accepted combinations:
//   - all three nil (the caller will reuse an already-loaded certificate);
//   - certFile and keyFile both non-nil, each a non-empty string path or
//     non-empty []byte content (pkcs12File is then not inspected);
//   - otherwise pkcs12File non-nil, as a non-empty string or []byte.
//
// Any other combination, an empty value, or a value of another type returns
// an error. certFile is checked before keyFile, so the reported error is
// deterministic.
func checkCertFilePathOrContent(certFile, keyFile, pkcs12File interface{}) error {
	// validate rejects an empty string/[]byte and any other type.
	// (An empty string is util.NULL elsewhere in this package.)
	validate := func(name string, v interface{}) error {
		switch v := v.(type) {
		case string:
			if v == "" {
				return fmt.Errorf("%s is empty", name)
			}
		case []byte:
			if len(v) == 0 {
				return fmt.Errorf("%s is empty", name)
			}
		default:
			return fmt.Errorf("%s type error", name)
		}
		return nil
	}

	switch {
	case certFile == nil && keyFile == nil && pkcs12File == nil:
		return nil
	case certFile != nil && keyFile != nil:
		// Fixed order instead of ranging over a map: map iteration order is
		// random, which made the error for two bad inputs vary between calls.
		if err := validate("certFile", certFile); err != nil {
			return err
		}
		return validate("keyFile", keyFile)
	case pkcs12File != nil:
		return validate("pkcs12File", pkcs12File)
	default:
		return errors.New("certFile keyFile must all nil or all not nil")
	}
}
// GetReleaseSign computes the WeChat Pay sign for the production environment.
//
// The string produced by bm.EncodeWeChatSignParams(apiKey) is hashed with
// HMAC-SHA256 keyed by apiKey when signType is SignType_HMAC_SHA256, and with
// MD5 for any other signType. The digest is returned as uppercase hex.
func GetReleaseSign(apiKey string, signType string, bm gopay.BodyMap) (sign string) {
	var h hash.Hash
	switch signType {
	case SignType_HMAC_SHA256:
		h = hmac.New(sha256.New, []byte(apiKey))
	default:
		h = md5.New()
	}
	signStr := bm.EncodeWeChatSignParams(apiKey)
	// hash.Hash.Write never returns an error.
	h.Write([]byte(signStr))
	sign = strings.ToUpper(hex.EncodeToString(h.Sum(nil)))
	return sign
}
// getReleaseSign computes the production-environment WeChat Pay sign, the same
// way GetReleaseSign does (HMAC-SHA256 keyed by apiKey for SignType_HMAC_SHA256,
// MD5 otherwise, uppercase hex output). It additionally logs the string being
// signed when w.DebugSwitch is gopay.DebugOn.
func (w *Client) getReleaseSign(apiKey string, signType string, bm gopay.BodyMap) (sign string) {
	signParams := bm.EncodeWeChatSignParams(apiKey)
	if w.DebugSwitch == gopay.DebugOn {
		xlog.Debugf("Wechat_Request_SignStr: %s", signParams)
	}
	var h hash.Hash
	if signType != SignType_HMAC_SHA256 {
		h = md5.New()
	} else {
		h = hmac.New(sha256.New, []byte(apiKey))
	}
	h.Write([]byte(signParams))
	digest := h.Sum(nil)
	return strings.ToUpper(hex.EncodeToString(digest))
}
// 获取微信支付沙箱环境Sign值
func GetSandBoxSign(ctx context.Context, mchId, apiKey string, bm gopay.BodyMap) (sign string, err error) {
var (
sandBoxApiKey string
h hash.Hash
)
if sandBoxApiKey, err = getSanBoxKey(ctx, mchId, util.RandomString(32), apiKey, SignType_MD5); err != nil {
return
}
h = md5.New()
h.Write([]byte(bm.EncodeWeChatSignParams(sandBoxApiKey)))
sign = strings.ToUpper(hex.EncodeToString(h.Sum(nil)))
return
}
// 获取微信支付沙箱环境Sign值
func (w *Client) getSandBoxSign(ctx context.Context, mchId, apiKey string, bm gopay.BodyMap) (sign string, err error) {
var (
sandBoxApiKey string
h hash.Hash
)
if sandBoxApiKey, err = getSanBoxKey(ctx, mchId, util.RandomString(32), apiKey, SignType_MD5); err != nil {
return
}
h = md5.New()
signParams := bm.EncodeWeChatSignParams(sandBoxApiKey)
if w.DebugSwitch == gopay.DebugOn {
xlog.Debugf("Wechat_Request_SignStr: %s", signParams)
}
h.Write([]byte(signParams))
sign = strings.ToUpper(hex.EncodeToString(h.Sum(nil)))
return
}
// getSanBoxKey fetches the sandbox sign key from WeChat's API for mchId.
//
// The request body (mch_id, nonce_str) is signed with the production apiKey
// using signType, and the result of getSanBoxSignKey is returned as-is.
func getSanBoxKey(ctx context.Context, mchId, nonceStr, apiKey, signType string) (key string, err error) {
	params := make(gopay.BodyMap)
	params.Set("mch_id", mchId)
	params.Set("nonce_str", nonceStr)
	// The key request itself is signed with the production (non-sandbox) key.
	reqSign := GetReleaseSign(apiKey, signType, params)
	return getSanBoxSignKey(ctx, mchId, nonceStr, reqSign)
}
// getSanBoxSignKey POSTs an XML request (mch_id, nonce_str, sign) to the
// sandboxGetSignKey endpoint and returns the SandboxSignkey from the reply.
//
// It returns util.NULL plus an error when the HTTP call or decoding fails,
// or when the reply's ReturnCode is "FAIL" (the error text is ReturnMsg).
func getSanBoxSignKey(ctx context.Context, mchId, nonceStr, sign string) (key string, err error) {
	body := make(gopay.BodyMap)
	body.Set("mch_id", mchId)
	body.Set("nonce_str", nonceStr)
	body.Set("sign", sign)

	rsp := new(getSignKeyResponse)
	req := xhttp.NewClient().Type(xhttp.TypeXML).Post(sandboxGetSignKey)
	if _, err = req.SendString(GenerateXml(body)).EndStruct(ctx, rsp); err != nil {
		return util.NULL, err
	}
	if rsp.ReturnCode != "FAIL" {
		return rsp.SandboxSignkey, nil
	}
	return util.NULL, errors.New(rsp.ReturnMsg)
}
// GenerateXml serializes bm to XML for use as a request body.
// A marshalling error is not reported; util.NULL is returned instead.
func GenerateXml(bm gopay.BodyMap) (reqXml string) {
	bs, err := xml.Marshal(bm)
	if err == nil {
		reqXml = string(bs)
		return reqXml
	}
	return util.NULL
}