293 lines
8.3 KiB
Go
293 lines
8.3 KiB
Go
package wechat
|
||
|
||
import (
|
||
"context"
|
||
"crypto/hmac"
|
||
"crypto/md5"
|
||
"crypto/sha256"
|
||
"crypto/tls"
|
||
"encoding/hex"
|
||
"encoding/pem"
|
||
"encoding/xml"
|
||
"errors"
|
||
"fmt"
|
||
"hash"
|
||
"io/ioutil"
|
||
"strings"
|
||
|
||
"github.com/go-pay/gopay"
|
||
"github.com/go-pay/gopay/pkg/util"
|
||
"github.com/go-pay/gopay/pkg/xhttp"
|
||
"github.com/go-pay/gopay/pkg/xlog"
|
||
"golang.org/x/crypto/pkcs12"
|
||
)
|
||
|
||
type Country int
|
||
|
||
// 设置支付国家(默认:中国国内)
|
||
// 根据支付地区情况设置国家
|
||
// country:<China:中国国内,China2:中国国内(冗灾方案),SoutheastAsia:东南亚,Other:其他国家>
|
||
func (w *Client) SetCountry(country Country) (client *Client) {
|
||
w.mu.Lock()
|
||
switch country {
|
||
case China:
|
||
w.BaseURL = baseUrlCh
|
||
case China2:
|
||
w.BaseURL = baseUrlCh2
|
||
case SoutheastAsia:
|
||
w.BaseURL = baseUrlHk
|
||
case Other:
|
||
w.BaseURL = baseUrlUs
|
||
default:
|
||
w.BaseURL = baseUrlCh
|
||
}
|
||
w.mu.Unlock()
|
||
return w
|
||
}
|
||
|
||
// 添加微信pem证书文件路径
|
||
// certFilePath:apiclient_cert.pem 文件路径
|
||
// keyFilePath:apiclient_key.pem 文件路径
|
||
func (w *Client) AddCertPemFilePath(certFilePath, keyFilePath string) (err error) {
|
||
return w.addCertFileContentOrPath(certFilePath, keyFilePath, nil)
|
||
}
|
||
|
||
// 添加微信pkcs12证书文件路径
|
||
// pkcs12FilePath:apiclient_cert.p12 文件路径
|
||
func (w *Client) AddCertPkcs12FilePath(pkcs12FilePath string) (err error) {
|
||
return w.addCertFileContentOrPath(nil, nil, pkcs12FilePath)
|
||
}
|
||
|
||
// 添加微信pem证书内容[]byte
|
||
// certFileContent:apiclient_cert.pem 证书内容[]byte
|
||
// keyFileContent:apiclient_key.pem 证书内容[]byte
|
||
func (w *Client) AddCertPemFileContent(certFileContent, keyFileContent []byte) (err error) {
|
||
return w.addCertFileContentOrPath(certFileContent, keyFileContent, nil)
|
||
}
|
||
|
||
// 添加微信pkcs12证书内容[]byte
|
||
// p12FileContent:apiclient_cert.p12 证书内容[]byte
|
||
func (w *Client) AddCertPkcs12FileContent(p12FileContent []byte) (err error) {
|
||
return w.addCertFileContentOrPath(nil, nil, p12FileContent)
|
||
}
|
||
|
||
// 添加微信证书文件 Path 路径或证书内容
|
||
// 注意:只传pem证书或只传pkcs12证书均可,无需3个证书全传
|
||
func (w *Client) addCertFileContentOrPath(certFile, keyFile, pkcs12File interface{}) (err error) {
|
||
if err = checkCertFilePathOrContent(certFile, keyFile, pkcs12File); err != nil {
|
||
return
|
||
}
|
||
var config *tls.Config
|
||
if config, err = w.addCertConfig(certFile, keyFile, pkcs12File); err != nil {
|
||
return
|
||
}
|
||
w.mu.Lock()
|
||
w.Certificate = &config.Certificates[0]
|
||
w.mu.Unlock()
|
||
return
|
||
}
|
||
|
||
func (w *Client) addCertConfig(certFile, keyFile, pkcs12File interface{}) (tlsConfig *tls.Config, err error) {
|
||
if certFile == nil && keyFile == nil && pkcs12File == nil {
|
||
w.mu.RLock()
|
||
defer w.mu.RUnlock()
|
||
if w.Certificate != nil {
|
||
tlsConfig = &tls.Config{
|
||
Certificates: []tls.Certificate{*w.Certificate},
|
||
InsecureSkipVerify: true,
|
||
}
|
||
return tlsConfig, nil
|
||
}
|
||
return nil, errors.New("cert parse failed or nil")
|
||
}
|
||
|
||
var (
|
||
certPem, keyPem []byte
|
||
certificate tls.Certificate
|
||
)
|
||
if certFile != nil && keyFile != nil {
|
||
if _, ok := certFile.([]byte); ok {
|
||
certPem = certFile.([]byte)
|
||
} else {
|
||
certPem, err = ioutil.ReadFile(certFile.(string))
|
||
}
|
||
if _, ok := keyFile.([]byte); ok {
|
||
keyPem = keyFile.([]byte)
|
||
} else {
|
||
keyPem, err = ioutil.ReadFile(keyFile.(string))
|
||
}
|
||
if err != nil {
|
||
return nil, fmt.Errorf("ioutil.ReadFile:%w", err)
|
||
}
|
||
} else if pkcs12File != nil {
|
||
var pfxData []byte
|
||
if _, ok := pkcs12File.([]byte); ok {
|
||
pfxData = pkcs12File.([]byte)
|
||
} else {
|
||
if pfxData, err = ioutil.ReadFile(pkcs12File.(string)); err != nil {
|
||
return nil, fmt.Errorf("ioutil.ReadFile:%w", err)
|
||
}
|
||
}
|
||
blocks, err := pkcs12.ToPEM(pfxData, w.MchId)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("pkcs12.ToPEM:%w", err)
|
||
}
|
||
for _, b := range blocks {
|
||
keyPem = append(keyPem, pem.EncodeToMemory(b)...)
|
||
}
|
||
certPem = keyPem
|
||
}
|
||
if certPem != nil && keyPem != nil {
|
||
if certificate, err = tls.X509KeyPair(certPem, keyPem); err != nil {
|
||
return nil, fmt.Errorf("tls.LoadX509KeyPair:%w", err)
|
||
}
|
||
tlsConfig = &tls.Config{
|
||
Certificates: []tls.Certificate{certificate},
|
||
InsecureSkipVerify: true,
|
||
}
|
||
return tlsConfig, nil
|
||
}
|
||
return nil, errors.New("cert files must all nil or all not nil")
|
||
}
|
||
|
||
func checkCertFilePathOrContent(certFile, keyFile, pkcs12File interface{}) error {
|
||
if certFile == nil && keyFile == nil && pkcs12File == nil {
|
||
return nil
|
||
}
|
||
if certFile != nil && keyFile != nil {
|
||
files := map[string]interface{}{"certFile": certFile, "keyFile": keyFile}
|
||
for varName, v := range files {
|
||
switch v := v.(type) {
|
||
case string:
|
||
if v == util.NULL {
|
||
return fmt.Errorf("%s is empty", varName)
|
||
}
|
||
case []byte:
|
||
if len(v) == 0 {
|
||
return fmt.Errorf("%s is empty", varName)
|
||
}
|
||
default:
|
||
return fmt.Errorf("%s type error", varName)
|
||
}
|
||
}
|
||
return nil
|
||
} else if pkcs12File != nil {
|
||
switch pkcs12File := pkcs12File.(type) {
|
||
case string:
|
||
if pkcs12File == util.NULL {
|
||
return errors.New("pkcs12File is empty")
|
||
}
|
||
case []byte:
|
||
if len(pkcs12File) == 0 {
|
||
return errors.New("pkcs12File is empty")
|
||
}
|
||
default:
|
||
return errors.New("pkcs12File type error")
|
||
}
|
||
return nil
|
||
} else {
|
||
return errors.New("certFile keyFile must all nil or all not nil")
|
||
}
|
||
}
|
||
|
||
// 获取微信支付正式环境Sign值
|
||
func GetReleaseSign(apiKey string, signType string, bm gopay.BodyMap) (sign string) {
|
||
var h hash.Hash
|
||
if signType == SignType_HMAC_SHA256 {
|
||
h = hmac.New(sha256.New, []byte(apiKey))
|
||
} else {
|
||
h = md5.New()
|
||
}
|
||
h.Write([]byte(bm.EncodeWeChatSignParams(apiKey)))
|
||
return strings.ToUpper(hex.EncodeToString(h.Sum(nil)))
|
||
}
|
||
|
||
// 获取微信支付正式环境Sign值
|
||
func (w *Client) getReleaseSign(apiKey string, signType string, bm gopay.BodyMap) (sign string) {
|
||
var h hash.Hash
|
||
if signType == SignType_HMAC_SHA256 {
|
||
h = hmac.New(sha256.New, []byte(apiKey))
|
||
} else {
|
||
h = md5.New()
|
||
}
|
||
signParams := bm.EncodeWeChatSignParams(apiKey)
|
||
if w.DebugSwitch == gopay.DebugOn {
|
||
xlog.Debugf("Wechat_Request_SignStr: %s", signParams)
|
||
}
|
||
h.Write([]byte(signParams))
|
||
return strings.ToUpper(hex.EncodeToString(h.Sum(nil)))
|
||
}
|
||
|
||
// 获取微信支付沙箱环境Sign值
|
||
func GetSandBoxSign(ctx context.Context, mchId, apiKey string, bm gopay.BodyMap) (sign string, err error) {
|
||
var (
|
||
sandBoxApiKey string
|
||
h hash.Hash
|
||
)
|
||
if sandBoxApiKey, err = getSanBoxKey(ctx, mchId, util.RandomString(32), apiKey, SignType_MD5); err != nil {
|
||
return
|
||
}
|
||
h = md5.New()
|
||
h.Write([]byte(bm.EncodeWeChatSignParams(sandBoxApiKey)))
|
||
sign = strings.ToUpper(hex.EncodeToString(h.Sum(nil)))
|
||
return
|
||
}
|
||
|
||
// 获取微信支付沙箱环境Sign值
|
||
func (w *Client) getSandBoxSign(ctx context.Context, mchId, apiKey string, bm gopay.BodyMap) (sign string, err error) {
|
||
var (
|
||
sandBoxApiKey string
|
||
h hash.Hash
|
||
)
|
||
if sandBoxApiKey, err = getSanBoxKey(ctx, mchId, util.RandomString(32), apiKey, SignType_MD5); err != nil {
|
||
return
|
||
}
|
||
h = md5.New()
|
||
signParams := bm.EncodeWeChatSignParams(sandBoxApiKey)
|
||
if w.DebugSwitch == gopay.DebugOn {
|
||
xlog.Debugf("Wechat_Request_SignStr: %s", signParams)
|
||
}
|
||
h.Write([]byte(signParams))
|
||
sign = strings.ToUpper(hex.EncodeToString(h.Sum(nil)))
|
||
return
|
||
}
|
||
|
||
// 从微信提供的接口获取:SandboxSignKey
|
||
func getSanBoxKey(ctx context.Context, mchId, nonceStr, apiKey, signType string) (key string, err error) {
|
||
bm := make(gopay.BodyMap)
|
||
bm.Set("mch_id", mchId)
|
||
bm.Set("nonce_str", nonceStr)
|
||
// 沙箱环境:获取沙箱环境ApiKey
|
||
if key, err = getSanBoxSignKey(ctx, mchId, nonceStr, GetReleaseSign(apiKey, signType, bm)); err != nil {
|
||
return
|
||
}
|
||
return
|
||
}
|
||
|
||
// 从微信提供的接口获取:SandboxSignKey
|
||
func getSanBoxSignKey(ctx context.Context, mchId, nonceStr, sign string) (key string, err error) {
|
||
reqs := make(gopay.BodyMap)
|
||
reqs.Set("mch_id", mchId)
|
||
reqs.Set("nonce_str", nonceStr)
|
||
reqs.Set("sign", sign)
|
||
|
||
keyResponse := new(getSignKeyResponse)
|
||
_, err = xhttp.NewClient().Type(xhttp.TypeXML).Post(sandboxGetSignKey).SendString(GenerateXml(reqs)).EndStruct(ctx, keyResponse)
|
||
if err != nil {
|
||
return util.NULL, err
|
||
}
|
||
if keyResponse.ReturnCode == "FAIL" {
|
||
return util.NULL, errors.New(keyResponse.ReturnMsg)
|
||
}
|
||
return keyResponse.SandboxSignkey, nil
|
||
}
|
||
|
||
// 生成请求XML的Body体
|
||
func GenerateXml(bm gopay.BodyMap) (reqXml string) {
|
||
bs, err := xml.Marshal(bm)
|
||
if err != nil {
|
||
return util.NULL
|
||
}
|
||
return string(bs)
|
||
}
|