xframe/vendor/github.com/alicebob/miniredis/v2/cmd_scripting.go

344 lines
7.1 KiB
Go

package miniredis
import (
"crypto/sha1"
"encoding/hex"
"fmt"
"io"
"strconv"
"strings"
"sync"
luajson "github.com/alicebob/gopher-json"
lua "github.com/yuin/gopher-lua"
"github.com/yuin/gopher-lua/parse"
"github.com/alicebob/miniredis/v2/server"
)
// commandsScripting registers the Lua scripting commands (EVAL, EVALSHA and
// SCRIPT) with m's command server.
func commandsScripting(m *Miniredis) {
	handlers := []struct {
		name string
		fn   func(*server.Peer, string, []string)
	}{
		{"EVAL", m.cmdEval},
		{"EVALSHA", m.cmdEvalsha},
		{"SCRIPT", m.cmdScript},
	}
	for _, h := range handlers {
		m.srv.Register(h.name, h.fn)
	}
}
// parsedScripts caches compiled Lua prototypes (*lua.FunctionProto) keyed by
// the script source text; see compile. The zero value is an empty,
// ready-to-use sync.Map, so it is safe to share across goroutines.
// NOTE(review): entries are only ever added, so the cache is unbounded.
var parsedScripts sync.Map
// runLuaScript executes script in a fresh Lua state on behalf of peer c.
// Needs to run m.Lock()ed, from within withTx().
// Returns true if the lua was OK (and hence should be cached).
//
// args must be non-empty and is laid out as [numkeys, key..., arg...]: the
// numkeys values after numkeys become the Lua global KEYS, the remainder
// become ARGV (both as 1-based tables). If numkeys is not an integer, is
// negative, or exceeds the number of remaining args, an error reply is
// written to c and false is returned. If the script fails to compile or run,
// the error is written to c and false is returned. Otherwise the script's
// first return value is converted to a reply with luaToRedis.
//
// sha is passed to mkLua together with c; NOTE(review): presumably used to
// identify this script in errors from nested calls — confirm in mkLua.
func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, args []string) bool {
	l := lua.NewState(lua.Options{SkipOpenLibs: true})
	defer l.Close()

	// Open only the libraries listed below (io, os and channel are not
	// opened), following the library-loading example from the gopher-lua
	// documentation. A failure here is unexpected, so it panics.
	for _, pair := range []struct {
		n string
		f lua.LGFunction
	}{
		{lua.LoadLibName, lua.OpenPackage},
		{lua.BaseLibName, lua.OpenBase},
		{lua.CoroutineLibName, lua.OpenCoroutine},
		{lua.TabLibName, lua.OpenTable},
		{lua.StringLibName, lua.OpenString},
		{lua.MathLibName, lua.OpenMath},
		{lua.DebugLibName, lua.OpenDebug},
	} {
		if err := l.CallByParam(lua.P{
			Fn:      l.NewFunction(pair.f),
			NRet:    0,
			Protect: true,
		}, lua.LString(pair.n)); err != nil {
			panic(err)
		}
	}

	// Make the JSON module loadable and expose it as the global "cjson",
	// the name Redis scripts use for JSON support.
	luajson.Preload(l)
	requireGlobal(l, "cjson", "json")

	// set global variable KEYS
	keysTable := l.NewTable()
	keysS, args := args[0], args[1:]
	keysLen, err := strconv.Atoi(keysS)
	if err != nil {
		c.WriteError(msgInvalidInt)
		return false
	}
	if keysLen < 0 {
		c.WriteError(msgNegativeKeysNumber)
		return false
	}
	if keysLen > len(args) {
		c.WriteError(msgInvalidKeysNumber)
		return false
	}
	keys, args := args[:keysLen], args[keysLen:]
	// Lua tables are 1-indexed, hence i+1.
	for i, k := range keys {
		l.RawSet(keysTable, lua.LNumber(i+1), lua.LString(k))
	}
	l.SetGlobal("KEYS", keysTable)

	// set global variable ARGV from whatever follows the keys
	argvTable := l.NewTable()
	for i, a := range args {
		l.RawSet(argvTable, lua.LNumber(i+1), lua.LString(a))
	}
	l.SetGlobal("ARGV", argvTable)

	redisFuncs, redisConstants := mkLua(m.srv, c, sha)
	// Register command handlers: push (but do not call yet) a loader that
	// registers the "redis" module from redisFuncs plus redisConstants. It is
	// called below, after the globals protection has been installed.
	l.Push(l.NewFunction(func(l *lua.LState) int {
		mod := l.RegisterModule("redis", redisFuncs).(*lua.LTable)
		for k, v := range redisConstants {
			mod.RawSetString(k, v)
		}
		l.Push(mod)
		return 1
	}))

	// Install the strict-globals metatable on _G. protectGlobals is a fixed
	// script, so its error is ignored. NOTE(review): a failure here would
	// silently leave globals unprotected; consider panicking like the
	// library setup above.
	_ = doScript(l, protectGlobals)

	// Call the loader pushed above, with "redis" as its single argument and
	// no results kept.
	l.Push(lua.LString("redis"))
	l.Call(1, 0)

	if err := doScript(l, script); err != nil {
		c.WriteError(err.Error())
		return false
	}

	// The stack is expected to be empty before the user script runs, and
	// doScript leaves the script's return values on it, so index 1 holds the
	// first return value.
	luaToRedis(l, c, l.Get(1))
	return true
}
// doScript compiles script into a Lua prototype, reusing a cached prototype
// when one exists (see compile). It then runs the prototype in l with
// lua.MultRet, leaving all of the script's return values on l's stack.
//
// Both compile errors and runtime errors are reported with the message
// produced by errLuaParseError.
//
// Only the prototype cache is shared between calls, and it is safe for
// concurrent use. The *lua.LState itself is not goroutine-safe, so l must
// not be used concurrently.
func doScript(l *lua.LState, script string) error {
	proto, err := compile(script)
	if err != nil {
		// errLuaParseError returns a finished message. Pass it as an argument
		// rather than as the format string, so that a '%' in the Lua error
		// (which can echo user script text) is not treated as a verb.
		return fmt.Errorf("%s", errLuaParseError(err))
	}
	l.Push(l.NewFunctionFromProto(proto))
	if err := l.PCall(0, lua.MultRet, nil); err != nil {
		// Runtime errors are wrapped in the same format as compile errors.
		return fmt.Errorf("%s", errLuaParseError(err))
	}
	return nil
}
// compile returns the compiled prototype for script. On first use it parses
// and compiles the source and caches the result in parsedScripts, keyed by
// the source text. Later calls return the cached prototype.
//
// Two concurrent callers may both miss the cache and compile the same script.
// The last Store wins, which is harmless because both prototypes come from
// identical source.
func compile(script string) (*lua.FunctionProto, error) {
	if cached, hit := parsedScripts.Load(script); hit {
		return cached.(*lua.FunctionProto), nil
	}

	chunk, parseErr := parse.Parse(strings.NewReader(script), "<string>")
	if parseErr != nil {
		return nil, parseErr
	}
	proto, compileErr := lua.Compile(chunk, "")
	if compileErr != nil {
		return nil, compileErr
	}

	parsedScripts.Store(script, proto)
	return proto, nil
}
// cmdEval implements EVAL. args is laid out as [script, numkeys, key...,
// arg...]. When runLuaScript reports success, the script is stored in
// m.scripts under its SHA1 digest so that EVALSHA can run it later.
func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {
	if len(args) < 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) || m.checkPubsub(c, cmd) {
		return
	}
	if peerCtx := getCtx(c); peerCtx.nested {
		c.WriteError(msgNotFromScripts(peerCtx.nestedSHA))
		return
	}

	script := args[0]
	scriptArgs := args[1:]
	withTx(m, c, func(c *server.Peer, _ *connCtx) {
		digest := sha1Hex(script)
		if m.runLuaScript(c, digest, script, scriptArgs) {
			m.scripts[digest] = script
		}
	})
}
// cmdEvalsha implements EVALSHA. It runs a script that was previously stored
// by EVAL or SCRIPT LOAD, identified by its SHA1 digest. args is laid out as
// [sha1, numkeys, key..., arg...].
//
// Stored digests are the lowercase hex produced by sha1Hex, and Redis matches
// EVALSHA digests case-insensitively. The argument is therefore lowercased
// before the lookup. An unknown digest gets the msgNoScriptFound error reply.
func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) {
	if len(args) < 2 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	if m.checkPubsub(c, cmd) {
		return
	}
	ctx := getCtx(c)
	if ctx.nested {
		c.WriteError(msgNotFromScripts(ctx.nestedSHA))
		return
	}

	sha, args := strings.ToLower(args[0]), args[1:]
	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		script, ok := m.scripts[sha]
		if !ok {
			c.WriteError(msgNoScriptFound)
			return
		}
		// The script is already cached, so the success flag is not needed.
		m.runLuaScript(c, sha, script, args)
	})
}
// cmdScript implements SCRIPT LOAD, SCRIPT EXISTS and SCRIPT FLUSH.
//
// Arguments are validated before withTx. On a validation error, setDirty is
// called and an error reply is written. The actual work runs inside withTx.
// Digests given to SCRIPT EXISTS are matched case-insensitively, as Redis
// does, because m.scripts is keyed by the lowercase hex from sha1Hex.
func (m *Miniredis) cmdScript(c *server.Peer, cmd string, args []string) {
	if len(args) < 1 {
		setDirty(c)
		c.WriteError(errWrongNumber(cmd))
		return
	}
	if !m.handleAuth(c) {
		return
	}
	if m.checkPubsub(c, cmd) {
		return
	}
	ctx := getCtx(c)
	if ctx.nested {
		c.WriteError(msgNotFromScripts(ctx.nestedSHA))
		return
	}

	subcmd, args := args[0], args[1:]
	sub := strings.ToLower(subcmd)
	var script string
	switch sub {
	case "load":
		if len(args) != 1 {
			setDirty(c)
			c.WriteError(fmt.Sprintf(msgFScriptUsage, "LOAD"))
			return
		}
		script = args[0]
	case "exists":
		if len(args) == 0 {
			setDirty(c)
			c.WriteError(errWrongNumber("script|exists"))
			return
		}
	case "flush":
		// An optional SYNC/ASYNC modifier is accepted and otherwise ignored.
		if len(args) == 1 {
			switch strings.ToUpper(args[0]) {
			case "SYNC", "ASYNC":
				args = args[1:]
			}
		}
		if len(args) != 0 {
			setDirty(c)
			c.WriteError(msgScriptFlush)
			return
		}
	default:
		setDirty(c)
		c.WriteError(fmt.Sprintf(msgFScriptUsageSimple, strings.ToUpper(subcmd)))
		return
	}

	withTx(m, c, func(c *server.Peer, ctx *connCtx) {
		switch sub {
		case "load":
			// Only check that the script parses. It is compiled (and cached)
			// when it is first run through runLuaScript.
			if _, err := parse.Parse(strings.NewReader(script), "user_script"); err != nil {
				c.WriteError(errLuaParseError(err))
				return
			}
			sha := sha1Hex(script)
			m.scripts[sha] = script
			c.WriteBulk(sha)
		case "exists":
			c.WriteLen(len(args))
			for _, arg := range args {
				if _, ok := m.scripts[strings.ToLower(arg)]; ok {
					c.WriteInt(1)
				} else {
					c.WriteInt(0)
				}
			}
		case "flush":
			m.scripts = map[string]string{}
			c.WriteOK()
		}
	})
}
// sha1Hex returns the SHA-1 digest of s as a lowercase hexadecimal string.
// This is the key under which scripts are stored in m.scripts.
func sha1Hex(s string) string {
	hasher := sha1.New()
	// Writes to a hash.Hash never fail, so the result is deliberately ignored.
	_, _ = io.WriteString(hasher, s)
	digest := hasher.Sum(nil)
	return hex.EncodeToString(digest)
}
// requireGlobal calls Lua's require(modName) in l and binds the returned
// module to the global variable id. It panics if the require call fails.
func requireGlobal(l *lua.LState, id, modName string) {
	call := lua.P{
		Fn:      l.GetGlobal("require"),
		NRet:    1,
		Protect: true,
	}
	if err := l.CallByParam(call, lua.LString(modName)); err != nil {
		panic(err)
	}
	// require left the module on top of the stack; move it into a global.
	module := l.Get(-1)
	l.Pop(1)
	l.SetGlobal(id, module)
}
// protectGlobals is Lua source that runLuaScript runs before the user script.
// It installs a metatable on _G. When the code performing the access is not
// a C function (as reported by debug.getinfo(2, "S").what), then:
//   - assigning a global that does not exist yet raises
//     "Script attempted to create global variable ...", and
//   - reading a global that does not exist raises
//     "Script attempted to access nonexistent global variable ...".
//
// Otherwise the assignment is done with rawset and the read with rawget.
// Finally it sets the global debug to nil. The metamethods keep their own
// local reference to the debug library (dbg).
//
// It is based on: http://metalua.luaforge.net/src/lib/strict.lua.html
var protectGlobals = `
local dbg=debug
local mt = {}
setmetatable(_G, mt)
mt.__newindex = function (t, n, v)
if dbg.getinfo(2) then
local w = dbg.getinfo(2, "S").what
if w ~= "C" then
error("Script attempted to create global variable '"..tostring(n).."'", 2)
end
end
rawset(t, n, v)
end
mt.__index = function (t, n)
if dbg.getinfo(2) and dbg.getinfo(2, "S").what ~= "C" then
error("Script attempted to access nonexistent global variable '"..tostring(n).."'", 2)
end
return rawget(t, n)
end
debug = nil
`