add cache and rename some files

This commit is contained in:
2026-01-15 15:41:19 +01:00
parent 7879f15726
commit b64815d9ab
4753 changed files with 931902 additions and 1 deletions

View File

@@ -0,0 +1,133 @@
---@class vm
local vm = require 'vm.vm'
local util = require 'utility'
local guide = require 'parser.guide'
local simpleSwitch
simpleSwitch = util.switch()
: case 'goto'
: call(function (source, pushResult)
if source.node then
pushResult(source.node)
end
end)
: case 'doc.cast.name'
: call(function (source, pushResult)
local loc = guide.getLocal(source, source[1], source.start)
if loc then
pushResult(loc)
end
end)
: case 'doc.field'
: call(function (source, pushResult)
pushResult(source)
end)
---@param source parser.object
---@param pushResult fun(src: parser.object)
local function searchBySimple(source, pushResult)
simpleSwitch(source.type, source, pushResult)
end
---@param source parser.object
---@param pushResult fun(src: parser.object)
local function searchByLocalID(source, pushResult)
local idSources = vm.getVariableSets(source)
if not idSources then
return
end
for _, src in ipairs(idSources) do
pushResult(src)
end
end
local function searchByNode(source, pushResult)
local node = vm.compileNode(source)
local suri = guide.getUri(source)
for n in node:eachObject() do
if n.type == 'global' then
for _, set in ipairs(n:getSets(suri)) do
pushResult(set)
end
else
pushResult(n)
end
end
end
---@param source parser.object
---@return parser.object[]
function vm.getDefs(source)
local results = {}
local mark = {}
local hasLocal
local function pushResult(src)
if src.type == 'local' then
if hasLocal then
return
end
hasLocal = true
if source.type ~= 'local'
and source.type ~= 'getlocal'
and source.type ~= 'setlocal'
and source.type ~= 'doc.cast.name' then
return
end
end
if not mark[src] then
mark[src] = true
if guide.isAssign(src)
or guide.isLiteral(src) then
results[#results+1] = src
end
end
end
searchBySimple(source, pushResult)
searchByLocalID(source, pushResult)
vm.compileByNodeChain(source, pushResult)
searchByNode(source, pushResult)
return results
end
local HAS_DEF_ERR = false -- the error object for comparing
local function checkHasDef(checkFunc, source, pushResult)
local _, err = pcall(checkFunc, source, pushResult)
return err == HAS_DEF_ERR
end
---@param source parser.object
function vm.hasDef(source)
local mark = {}
local hasLocal
local function pushResult(src)
if src.type == 'local' then
if hasLocal then
return
end
hasLocal = true
if source.type ~= 'local'
and source.type ~= 'getlocal'
and source.type ~= 'setlocal'
and source.type ~= 'doc.cast.name' then
return
end
end
if not mark[src] then
mark[src] = true
if guide.isAssign(src)
or guide.isLiteral(src) then
-- break out on 1st result using error() with a unique error object
error(HAS_DEF_ERR)
end
end
end
return checkHasDef(searchBySimple, source, pushResult)
or checkHasDef(searchByLocalID, source, pushResult)
or checkHasDef(vm.compileByNodeChain, source, pushResult)
or checkHasDef(searchByNode, source, pushResult)
end

View File

@@ -0,0 +1,525 @@
local files = require 'files'
local await = require 'await'
local guide = require 'parser.guide'
---@class vm
local vm = require 'vm.vm'
local config = require 'config'
---@class parser.object
---@field package _castTargetHead? parser.object | vm.global | false
---@field package _validVersions? table<string, boolean>
---@field package _deprecated? parser.object | false
---@field package _async? boolean
---@field package _nodiscard? boolean
---获取class与alias
---@param suri uri
---@param name? string
---@return parser.object[]
function vm.getDocSets(suri, name)
if name then
local global = vm.getGlobal('type', name)
if not global then
return {}
end
return global:getSets(suri)
else
return vm.getGlobalSets(suri, 'type')
end
end
---@param uri uri
---@return boolean
function vm.isMetaFile(uri)
local status = files.getState(uri)
if not status then
return false
end
local cache = files.getCache(uri)
if not cache then
return false
end
if cache.isMeta ~= nil then
return cache.isMeta
end
cache.isMeta = false
if not status.ast.docs then
return false
end
for _, doc in ipairs(status.ast.docs) do
if doc.type == 'doc.meta' then
cache.isMeta = true
cache.metaName = doc.name
return true
end
end
return false
end
---@param uri uri
---@return string?
function vm.getMetaName(uri)
if not vm.isMetaFile(uri) then
return nil
end
local cache = files.getCache(uri)
if not cache then
return nil
end
if not cache.metaName then
return nil
end
return cache.metaName[1]
end
---@param uri uri
---@return boolean
function vm.isMetaFileRequireable(uri)
if not vm.isMetaFile(uri) then
return false
end
return vm.getMetaName(uri) ~= '_'
end
---@param doc parser.object
---@return table<string, boolean>?
function vm.getValidVersions(doc)
if doc.type ~= 'doc.version' then
return
end
if doc._validVersions then
return doc._validVersions
end
local valids = {
['Lua 5.1'] = false,
['Lua 5.2'] = false,
['Lua 5.3'] = false,
['Lua 5.4'] = false,
['LuaJIT'] = false,
}
for _, version in ipairs(doc.versions) do
if version.ge and type(version.version) == 'number' then
for ver in pairs(valids) do
local verNumber = tonumber(ver:sub(-3))
if verNumber and verNumber >= version.version then
valids[ver] = true
end
end
elseif version.le and type(version.version) == 'number' then
for ver in pairs(valids) do
local verNumber = tonumber(ver:sub(-3))
if verNumber and verNumber <= version.version then
valids[ver] = true
end
end
elseif type(version.version) == 'number' then
valids[('Lua %.1f'):format(version.version)] = true
elseif 'JIT' == version.version then
valids['LuaJIT'] = true
end
end
if valids['Lua 5.1'] then
valids['LuaJIT'] = true
end
doc._validVersions = valids
return valids
end
---@param value parser.object
---@return parser.object?
local function getDeprecated(value)
if not value.bindDocs then
return nil
end
if value._deprecated ~= nil then
return value._deprecated or nil
end
for _, doc in ipairs(value.bindDocs) do
if doc.type == 'doc.deprecated' then
value._deprecated = doc
return doc
elseif doc.type == 'doc.version' then
local valids = vm.getValidVersions(doc)
if valids and not valids[config.get(guide.getUri(value), 'Lua.runtime.version')] then
value._deprecated = doc
return doc
end
end
end
if value.type == 'function' then
local doc = getDeprecated(value.parent)
if doc then
value._deprecated = doc
return doc
end
end
value._deprecated = false
return nil
end
---@param value parser.object
---@param deep boolean?
---@return parser.object?
function vm.getDeprecated(value, deep)
if deep then
local defs = vm.getDefs(value)
if #defs == 0 then
return nil
end
local deprecated
for _, def in ipairs(defs) do
if def.type == 'setglobal'
or def.type == 'setfield'
or def.type == 'setmethod'
or def.type == 'setindex'
or def.type == 'tablefield'
or def.type == 'tableindex' then
deprecated = getDeprecated(def)
if not deprecated then
return nil
end
end
end
return deprecated
else
return getDeprecated(value)
end
end
---@param value parser.object
---@param propagate boolean
---@param deepLevel integer?
---@return boolean
local function isAsync(value, propagate, deepLevel)
if value.type == 'function' then
if value._async ~= nil then --already calculated, directly return
return value._async
end
local asyncCache
if propagate then
asyncCache = vm.getCache 'async.propagate'
local result = asyncCache[value]
if result ~= nil then
return result
end
end
if value.bindDocs then --try parse the annotation
for _, doc in ipairs(value.bindDocs) do
if doc.type == 'doc.async' then
value._async = true
return true
end
end
end
if propagate then -- if enable async propagation, try check calling functions
if deepLevel and deepLevel > 50 then
return false
end
local isAsyncCall = vm.isAsyncCall
local callingAsync = guide.eachSourceType(value, 'call', function (source)
local parent = guide.getParentFunction(source)
if parent ~= value then
return nil
end
local nextLevel = (deepLevel or 1) + 1
local ok = isAsyncCall(source, nextLevel)
if ok then --if any calling function is async, directly return
return ok
end
--if not, try check the next calling function
return nil
end)
if callingAsync then
asyncCache[value] = true
return true
end
asyncCache[value] = false
end
value._async = false
return false
end
if value.type == 'main' then
return true
end
return value.async == true
end
---@param value parser.object
---@param deep boolean?
---@param deepLevel integer?
---@return boolean
function vm.isAsync(value, deep, deepLevel)
local uri = guide.getUri(value)
local propagate = config.get(uri, 'Lua.hint.awaitPropagate')
if isAsync(value, propagate, deepLevel) then
return true
end
if deep then
local defs = vm.getDefs(value)
if #defs == 0 then
return false
end
for _, def in ipairs(defs) do
if isAsync(def, propagate, deepLevel) then
return true
end
end
end
return false
end
---@param value parser.object
---@return boolean
local function isNoDiscard(value)
if value.type == 'function' then
if not value.bindDocs then
return false
end
if value._nodiscard ~= nil then
return value._nodiscard
end
for _, doc in ipairs(value.bindDocs) do
if doc.type == 'doc.nodiscard' then
value._nodiscard = true
return true
end
end
value._nodiscard = false
return false
end
return false
end
---@param value parser.object
---@param deep boolean?
---@return boolean
function vm.isNoDiscard(value, deep)
if isNoDiscard(value) then
return true
end
if deep then
local defs = vm.getDefs(value)
if #defs == 0 then
return false
end
for _, def in ipairs(defs) do
if isNoDiscard(def) then
return true
end
end
end
return false
end
---@param param parser.object
---@return boolean
local function isCalledInFunction(param)
if not param.ref then
return false
end
local func = guide.getParentFunction(param)
for _, ref in ipairs(param.ref) do
if ref.type == 'getlocal' then
if ref.parent.type == 'call'
and guide.getParentFunction(ref) == func then
return true
end
if ref.parent.type == 'callargs'
and ref.parent[1] == ref
and guide.getParentFunction(ref) == func then
if ref.parent.parent.node.special == 'pcall'
or ref.parent.parent.node.special == 'xpcall' then
return true
end
end
end
end
return false
end
---@param node parser.object
---@param index integer
---@return boolean
local function isLinkedCall(node, index)
for _, def in ipairs(vm.getDefs(node)) do
if def.type == 'function' then
local param = def.args and def.args[index]
if param then
if isCalledInFunction(param) then
return true
end
end
end
end
return false
end
---@param node parser.object
---@param index integer
---@return boolean
function vm.isLinkedCall(node, index)
return isLinkedCall(node, index)
end
---@param call parser.object
---@param deepLevel integer?
---@return boolean
function vm.isAsyncCall(call, deepLevel)
if vm.isAsync(call.node, true, deepLevel) then
return true
end
if not call.args then
return false
end
for i, arg in ipairs(call.args) do
if vm.isAsync(arg, true, deepLevel)
and isLinkedCall(call.node, i) then
return true
end
end
return false
end
---@param doc parser.object
---@param results table[]
local function makeDiagRange(doc, results)
local names
if doc.names then
names = {}
for _, nameUnit in ipairs(doc.names) do
local name = nameUnit[1]
names[name] = true
end
end
local row = guide.rowColOf(doc.start)
if doc.mode == 'disable-next-line' then
results[#results+1] = {
mode = 'disable',
names = names,
row = row + 1,
source = doc,
}
results[#results+1] = {
mode = 'enable',
names = names,
row = row + 2,
source = doc,
}
elseif doc.mode == 'disable-line' then
results[#results+1] = {
mode = 'disable',
names = names,
row = row,
source = doc,
}
results[#results+1] = {
mode = 'enable',
names = names,
row = row + 1,
source = doc,
}
elseif doc.mode == 'disable' then
results[#results+1] = {
mode = 'disable',
names = names,
row = row + 1,
source = doc,
}
elseif doc.mode == 'enable' then
results[#results+1] = {
mode = 'enable',
names = names,
row = row + 1,
source = doc,
}
end
end
---@param uri uri
---@param position integer
---@param name string
---@param err? boolean
---@return boolean
function vm.isDiagDisabledAt(uri, position, name, err)
local status = files.getState(uri)
if not status then
return false
end
if not status.ast.docs then
return false
end
local cache = files.getCache(uri)
if not cache then
return false
end
if not cache.diagnosticRanges then
cache.diagnosticRanges = {}
for _, doc in ipairs(status.ast.docs) do
if doc.type == 'doc.diagnostic' then
makeDiagRange(doc, cache.diagnosticRanges)
end
end
table.sort(cache.diagnosticRanges, function (a, b)
return a.row < b.row
end)
end
if #cache.diagnosticRanges == 0 then
return false
end
local myRow = guide.rowColOf(position)
local count = 0
for _, range in ipairs(cache.diagnosticRanges) do
if range.row <= myRow then
if (range.names and range.names[name])
or (not range.names and not err) then
if range.mode == 'disable' then
count = count + 1
elseif range.mode == 'enable' then
count = count - 1
end
end
else
break
end
end
return count > 0
end
---@param doc parser.object
---@return (parser.object | vm.global)?
function vm.getCastTargetHead(doc)
if doc._castTargetHead ~= nil then
return doc._castTargetHead or nil
end
local name = doc.name[1]:match '^[^%.]+'
if not name then
doc._castTargetHead = false
return nil
end
local loc = guide.getLocal(doc, name, doc.start)
if loc then
doc._castTargetHead = loc
return loc
end
local global = vm.getGlobal('variable', name)
if global then
doc._castTargetHead = global
return global
end
return nil
end
---@param doc parser.object
---@param key string
---@return boolean
function vm.docHasAttr(doc, key)
if not doc.docAttr then
return false
end
for _, name in ipairs(doc.docAttr.names) do
if name[1] == key then
return true
end
end
return false
end

View File

@@ -0,0 +1,59 @@
---@class vm
local vm = require 'vm.vm'
local util = require 'utility'
local guide = require 'parser.guide'
local searchByNodeSwitch = util.switch()
: case 'global'
---@param global vm.global
: call(function (suri, global, pushResult)
for _, set in ipairs(global:getSets(suri)) do
pushResult(set)
end
end)
: default(function (_suri, source, pushResult)
pushResult(source)
end)
local function searchByLocalID(source, pushResult)
local fields = vm.getVariableFields(source, true)
if fields then
for _, field in ipairs(fields) do
pushResult(field)
end
end
end
local function searchByNode(source, pushResult, mark)
mark = mark or {}
if mark[source] then
return
end
mark[source] = true
local uri = guide.getUri(source)
vm.compileByParentNode(source, vm.ANY, function (field)
searchByNodeSwitch(field.type, uri, field, pushResult)
end)
vm.compileByNodeChain(source, function (src)
searchByNode(src, pushResult, mark)
end)
end
---@param source parser.object
---@return parser.object[]
function vm.getFields(source)
local results = {}
local mark = {}
local function pushResult(src)
if not mark[src] then
mark[src] = true
results[#results+1] = src
end
end
searchByLocalID(source, pushResult)
searchByNode(source, pushResult)
return results
end

View File

@@ -0,0 +1,533 @@
---@class vm
local vm = require 'vm.vm'
local guide = require 'parser.guide'
local util = require 'utility'
---@param arg parser.object
---@return parser.object?
local function getDocParam(arg)
if not arg.bindDocs then
return nil
end
for _, doc in ipairs(arg.bindDocs) do
if doc.type == 'doc.param'
and doc.param[1] == arg[1] then
return doc
end
end
return nil
end
---@param func parser.object
---@return integer min
---@return number max
---@return integer def
function vm.countParamsOfFunction(func)
local min = 0
local max = 0
local def = 0
if func.type == 'function' then
if func.args then
max = #func.args
def = max
for i = #func.args, 1, -1 do
local arg = func.args[i]
if arg.type == '...' then
max = math.huge
elseif arg.type == 'self'
and i == 1 then
min = i
break
elseif getDocParam(arg)
and not vm.compileNode(arg):isNullable() then
min = i
break
end
end
end
end
if func.type == 'doc.type.function' then
if func.args then
max = #func.args
def = max
for i = #func.args, 1, -1 do
local arg = func.args[i]
if arg.name and arg.name[1] =='...' then
max = math.huge
elseif not vm.compileNode(arg):isNullable() then
min = i
break
end
end
end
end
return min, max, def
end
---@param source parser.object
---@return integer min
---@return number max
---@return integer def
function vm.countParamsOfSource(source)
local min = 0
local max = 0
local def = 0
local overloads = {}
if source.bindDocs then
for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.overload' then
overloads[doc.overload] = true
end
end
end
local hasDocFunction
for nd in vm.compileNode(source):eachObject() do
if nd.type == 'doc.type.function' and not overloads[nd] then
hasDocFunction = true
---@cast nd parser.object
local dmin, dmax, ddef = vm.countParamsOfFunction(nd)
if dmin > min then
min = dmin
end
if dmax > max then
max = dmax
end
if ddef > def then
def = ddef
end
end
end
if not hasDocFunction then
local dmin, dmax, ddef = vm.countParamsOfFunction(source)
if dmin > min then
min = dmin
end
if dmax > max then
max = dmax
end
if ddef > def then
def = ddef
end
end
return min, max, def
end
---@param node vm.node
---@return integer min
---@return number max
---@return integer def
function vm.countParamsOfNode(node)
local min, max, def
for n in node:eachObject() do
if n.type == 'function'
or n.type == 'doc.type.function' then
---@cast n parser.object
local fmin, fmax, fdef = vm.countParamsOfFunction(n)
if not min or fmin < min then
min = fmin
end
if not max or fmax > max then
max = fmax
end
if not def or fdef > def then
def = fdef
end
end
end
return min or 0, max or math.huge, def or 0
end
---@param func parser.object
---@param onlyDoc? boolean
---@param mark? table
---@return integer min
---@return number max
---@return integer def
function vm.countReturnsOfFunction(func, onlyDoc, mark)
if func.type == 'function' then
---@type integer?, number?, integer?
local min, max, def
local hasDocReturn
if func.bindDocs then
local lastReturn
local n = 0
---@type integer?, number?, integer?
local dmin, dmax, ddef
for _, doc in ipairs(func.bindDocs) do
if doc.type == 'doc.return' then
hasDocReturn = true
for _, ret in ipairs(doc.returns) do
n = n + 1
lastReturn = ret
dmax = n
ddef = n
if (not ret.name or ret.name[1] ~= '...')
and not vm.compileNode(ret):isNullable() then
dmin = n
end
end
end
end
if lastReturn then
if lastReturn.name and lastReturn.name[1] == '...' then
dmax = math.huge
end
end
if dmin and (not min or (dmin < min)) then
min = dmin
end
if dmax and (not max or (dmax > max)) then
max = dmax
end
if ddef and (not def or (ddef > def)) then
def = ddef
end
end
if not onlyDoc and not hasDocReturn and func.returns then
for _, ret in ipairs(func.returns) do
local dmin, dmax, ddef = vm.countList(ret, mark)
if not min or dmin < min then
min = dmin
end
if not max or dmax > max then
max = dmax
end
if not def or ddef > def then
def = ddef
end
end
end
return min or 0, max or math.huge, def or 0
end
if func.type == 'doc.type.function' then
return vm.countList(func.returns)
end
error('not a function')
end
---@param source parser.object
---@return integer min
---@return number max
---@return integer def
function vm.countReturnsOfSource(source)
local overloads = {}
local hasDocFunction
local min, max, def
if source.bindDocs then
for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.overload' then
overloads[doc.overload] = true
local dmin, dmax, ddef = vm.countReturnsOfFunction(doc.overload)
if not min or dmin < min then
min = dmin
end
if not max or dmax > max then
max = dmax
end
if not def or ddef > def then
def = ddef
end
end
end
end
for nd in vm.compileNode(source):eachObject() do
if nd.type == 'doc.type.function' and not overloads[nd] then
---@cast nd parser.object
hasDocFunction = true
local dmin, dmax, ddef = vm.countReturnsOfFunction(nd)
if not min or dmin < min then
min = dmin
end
if not max or dmax > max then
max = dmax
end
if not def or ddef > def then
def = ddef
end
end
end
if not hasDocFunction then
local dmin, dmax, ddef = vm.countReturnsOfFunction(source, true)
if not min or dmin < min then
min = dmin
end
if not max or dmax > max then
max = dmax
end
if not def or ddef > def then
def = ddef
end
end
return min, max, def
end
---@param func parser.object
---@param mark? table
---@return integer min
---@return number max
---@return integer def
function vm.countReturnsOfCall(func, args, mark)
local funcs = vm.getMatchedFunctions(func, args, mark)
if not funcs then
return 0, math.huge, 0
end
---@type integer?, number?, integer?
local min, max, def
for _, f in ipairs(funcs) do
local rmin, rmax, rdef = vm.countReturnsOfFunction(f, false, mark)
if not min or rmin < min then
min = rmin
end
if not max or rmax > max then
max = rmax
end
if not def or rdef > def then
def = rdef
end
end
return min or 0, max or math.huge, def or 0
end
---@param list parser.object[]?
---@param mark? table
---@return integer min
---@return number max
---@return integer def
function vm.countList(list, mark)
if not list then
return 0, 0, 0
end
local lastArg = list[#list]
if not lastArg then
return 0, 0, 0
end
---@type integer, number, integer
local min, max, def = #list, #list, #list
if lastArg.type == '...'
or lastArg.type == 'varargs'
or (lastArg.type == 'doc.type' and lastArg.name and lastArg.name[1] == '...') then
max = math.huge
elseif lastArg.type == 'call' then
if not mark then
mark = {}
end
if mark[lastArg] then
min = min - 1
max = math.huge
else
mark[lastArg] = true
local rmin, rmax, rdef = vm.countReturnsOfCall(lastArg.node, lastArg.args, mark)
return min - 1 + rmin, max - 1 + rmax, def - 1 + rdef
end
end
for i = min, 1, -1 do
local arg = list[i]
if arg.type == 'doc.type'
and ((arg.name and arg.name[1] == '...')
or vm.compileNode(arg):isNullable()) then
min = i - 1
else
break
end
end
return min, max, def
end
---@param uri uri
---@param args parser.object[]
---@return boolean
local function isAllParamMatched(uri, args, params)
if not params then
return false
end
for i = 1, #args do
if not params[i] then
break
end
local argNode = vm.compileNode(args[i])
local defNode = vm.compileNode(params[i])
if not vm.canCastType(uri, defNode, argNode) then
return false
end
end
return true
end
---@param uri uri
---@param args parser.object[]
---@param func parser.object
---@return number
local function calcFunctionMatchScore(uri, args, func)
if vm.isVarargFunctionWithOverloads(func)
or vm.isFunctionWithOnlyOverloads(func)
or not isAllParamMatched(uri, args, func.args)
then
return -1
end
local matchScore = 0
for i = 1, math.min(#args, #func.args) do
local arg, param = args[i], func.args[i]
local defLiterals, literalsCount = vm.getLiterals(param)
if defLiterals then
for n in vm.compileNode(arg):eachObject() do
-- if param's literals map contains arg's literal, this is narrower than a subtype match
if defLiterals[guide.getLiteral(n)] then
-- the more the literals defined in the param, the less bonus score will be added
-- this favors matching overload param with exact literal value, over alias/enum that has many literal values
matchScore = matchScore + 1/literalsCount
break
end
end
end
end
return matchScore
end
---@param func parser.object
---@param args? parser.object[]
---@return parser.object[]?
function vm.getExactMatchedFunctions(func, args)
local funcs = vm.getMatchedFunctions(func, args)
if not args or not funcs then
return funcs
end
if #funcs == 1 then
return funcs
end
local uri = guide.getUri(func)
local matchScores = {}
for i, n in ipairs(funcs) do
matchScores[i] = calcFunctionMatchScore(uri, args, n)
end
local maxMatchScore = math.max(table.unpack(matchScores))
if maxMatchScore == -1 then
-- all should be removed
return nil
end
local minMatchScore = math.min(table.unpack(matchScores))
if minMatchScore == maxMatchScore then
-- all should be kept
return funcs
end
-- remove functions that have matchScore < maxMatchScore
local needRemove = {}
for i, matchScore in ipairs(matchScores) do
if matchScore < maxMatchScore then
needRemove[#needRemove + 1] = i
end
end
util.tableMultiRemove(funcs, needRemove)
return funcs
end
---@param func parser.object
---@param args? parser.object[]
---@param mark? table
---@return parser.object[]?
function vm.getMatchedFunctions(func, args, mark)
local funcs = {}
local node = vm.compileNode(func)
for n in node:eachObject() do
if n.type == 'function'
or n.type == 'doc.type.function' then
funcs[#funcs+1] = n
end
end
local amin, amax = vm.countList(args, mark)
local matched = {}
for _, n in ipairs(funcs) do
local min, max = vm.countParamsOfFunction(n)
if amin >= min and amax <= max then
matched[#matched+1] = n
end
end
if #matched == 0 then
return nil
else
return matched
end
end
---@param func table
---@return boolean
function vm.isVarargFunctionWithOverloads(func)
if func.type ~= 'function' then
return false
end
if not func.args then
return false
end
if func._varargFunction ~= nil then
return func._varargFunction
end
if func.args[1] and func.args[1].type == 'self' then
if not func.args[2] or func.args[2].type ~= '...' then
func._varargFunction = false
return false
end
else
if not func.args[1] or func.args[1].type ~= '...' then
func._varargFunction = false
return false
end
end
if not func.bindDocs then
func._varargFunction = false
return false
end
for _, doc in ipairs(func.bindDocs) do
if doc.type == 'doc.overload' then
func._varargFunction = true
return true
end
end
func._varargFunction = false
return false
end
---@param func table
---@return boolean
function vm.isFunctionWithOnlyOverloads(func)
if func.type ~= 'function' then
return false
end
if func._onlyOverloadFunction ~= nil then
return func._onlyOverloadFunction
end
if not func.bindDocs then
func._onlyOverloadFunction = false
return false
end
local hasOverload = false
for _, doc in ipairs(func.bindDocs) do
if doc.type == 'doc.overload' then
hasOverload = true
elseif doc.type == 'doc.param'
or doc.type == 'doc.return'
then
-- has specified @param or @return, thus not only @overload
func._onlyOverloadFunction = false
return false
end
end
func._onlyOverloadFunction = hasOverload
return true
end
---@param func parser.object
---@return boolean
function vm.isEmptyFunction(func)
if #func > 0 then
return false
end
local startRow = guide.rowColOf(func.start)
local finishRow = guide.rowColOf(func.finish)
return finishRow - startRow <= 1
end

View File

@@ -0,0 +1,175 @@
---@class vm
local vm = require 'vm.vm'
---@class parser.object
---@field package _generic vm.generic
---@field package _resolved vm.node
---@class vm.generic
---@field sign vm.sign
---@field proto vm.object
local mt = {}
mt.__index = mt
mt.type = 'generic'
---@param source vm.object?
---@param resolved? table<string, vm.node>
---@return vm.object?
local function cloneObject(source, resolved)
if not resolved or not source then
return source
end
if source.type == 'doc.generic.name' then
local key = source[1]
local newName = {
type = source.type,
start = source.start,
finish = source.finish,
parent = source.parent,
[1] = source[1],
}
if resolved[key] then
vm.setNode(newName, resolved[key], true)
newName._resolved = resolved[key]
end
return newName
end
if source.type == 'doc.type' then
local newType = {
type = source.type,
start = source.start,
finish = source.finish,
parent = source.parent,
optional = source.optional,
types = {},
}
for i, typeUnit in ipairs(source.types) do
local newObj = cloneObject(typeUnit, resolved)
newType.types[i] = newObj
end
return newType
end
if source.type == 'doc.type.arg' then
local newArg = {
type = source.type,
start = source.start,
finish = source.finish,
parent = source.parent,
name = source.name,
extends = cloneObject(source.extends, resolved)
}
return newArg
end
if source.type == 'doc.type.array' then
local newArray = {
type = source.type,
start = source.start,
finish = source.finish,
parent = source.parent,
node = cloneObject(source.node, resolved),
}
return newArray
end
if source.type == 'doc.type.table' then
local newTable = {
type = source.type,
start = source.start,
finish = source.finish,
parent = source.parent,
fields = {},
}
for i, field in ipairs(source.fields) do
local newField = {
type = field.type,
start = field.start,
finish = field.finish,
parent = newTable,
name = cloneObject(field.name, resolved),
extends = cloneObject(field.extends, resolved),
}
newTable.fields[i] = newField
end
return newTable
end
if source.type == 'doc.type.function' then
local newDocFunc = {
type = source.type,
start = source.start,
finish = source.finish,
parent = source.parent,
args = {},
returns = {},
}
for i, arg in ipairs(source.args) do
local newObj = cloneObject(arg, resolved)
newObj.optional = arg.optional
newDocFunc.args[i] = newObj
end
for i, ret in ipairs(source.returns) do
local newObj = cloneObject(ret, resolved)
newObj.parent = newDocFunc
newObj.optional = ret.optional
newDocFunc.returns[i] = cloneObject(ret, resolved)
end
return newDocFunc
end
return source
end
---@param uri uri
---@param args parser.object
---@return vm.node
function mt:resolve(uri, args)
local resolved = self.sign:resolve(uri, args)
local protoNode = vm.compileNode(self.proto)
local result = vm.createNode()
for nd in protoNode:eachObject() do
if nd.type == 'global' or nd.type == 'variable' then
---@cast nd vm.global | vm.variable
result:merge(nd)
else
---@cast nd -vm.global, -vm.variable
local clonedObject = cloneObject(nd, resolved)
if clonedObject then
local clonedNode = vm.compileNode(clonedObject)
result:merge(clonedNode)
end
end
end
if protoNode:isOptional() then
result:addOptional()
end
return result
end
---@param source parser.object
---@return vm.node?
function vm.getGenericResolved(source)
if source.type ~= 'doc.generic.name' then
return nil
end
return source._resolved
end
---@param source parser.object
---@param generic vm.generic
function vm.setGeneric(source, generic)
source._generic = generic
end
---@param source parser.object
---@return vm.generic?
function vm.getGeneric(source)
return source._generic
end
---@param proto vm.object
---@param sign vm.sign
---@return vm.generic
function vm.createGeneric(proto, sign)
local generic = setmetatable({
sign = sign,
proto = proto,
}, mt)
return generic
end

View File

@@ -0,0 +1,740 @@
local util = require 'utility'
local scope = require 'workspace.scope'
local guide = require 'parser.guide'
local config = require 'config'
---@class vm
local vm = require 'vm.vm'
---@type table<string, vm.global>
local allGlobals = {}
---@type table<uri, table<string, boolean>>
local globalSubs = util.multiTable(2)
---@class parser.object
---@field package _globalBase parser.object
---@field package _globalBaseMap table<string, parser.object>
---@field global vm.global
---@class vm.global.link
---@field sets parser.object[]
---@field gets parser.object[]
---@class vm.global
---@field links table<uri, vm.global.link>
---@field setsCache? table<uri, parser.object[]>
---@field cate vm.global.cate
local mt = {}
mt.__index = mt
mt.type = 'global'
mt.name = ''
---@param uri uri
---@param source parser.object
function mt:addSet(uri, source)
local link = self.links[uri]
link.sets[#link.sets+1] = source
self.setsCache = nil
end
---@param uri uri
---@param source parser.object
function mt:addGet(uri, source)
local link = self.links[uri]
link.gets[#link.gets+1] = source
end
---@param suri uri
---@return parser.object[]
function mt:getSets(suri)
if not self.setsCache then
self.setsCache = {}
end
local scp = scope.getScope(suri)
local cacheUri = scp.uri or '<callback>'
if self.setsCache[cacheUri] then
return self.setsCache[cacheUri]
end
local clock = os.clock()
self.setsCache[cacheUri] = {}
local cache = self.setsCache[cacheUri]
for uri, link in pairs(self.links) do
if link.sets then
if scp:isVisible(uri) then
for _, source in ipairs(link.sets) do
cache[#cache+1] = source
end
end
end
end
local cost = os.clock() - clock
if cost > 0.1 then
log.warn('global-manager getSets costs', cost, self.name)
end
return cache
end
---@return parser.object[]
function mt:getAllSets()
if not self.setsCache then
self.setsCache = {}
end
local cache = self.setsCache['*']
if cache then
return cache
end
cache = {}
self.setsCache['*'] = cache
for _, link in pairs(self.links) do
if link.sets then
for _, source in ipairs(link.sets) do
cache[#cache+1] = source
end
end
end
return cache
end
---@param uri uri
function mt:dropUri(uri)
self.links[uri] = nil
self.setsCache = nil
end
---@return string
function mt:getName()
return self.name
end
---@return string
function mt:getCodeName()
return (self.name:gsub(vm.ID_SPLITE, '.'))
end
---@return string
function mt:asKeyName()
return self.cate .. '|' .. self.name
end
---@return string
function mt:getKeyName()
return self.name:match('[^' .. vm.ID_SPLITE .. ']+$')
end
---@return string?
function mt:getFieldName()
return self.name:match(vm.ID_SPLITE .. '(.-)$')
end
---@return boolean
function mt:isAlive()
return next(self.links) ~= nil
end
---@param uri uri
---@return parser.object?
function mt:getParentBase(uri)
local parentID = self.name:match('^(.-)' .. vm.ID_SPLITE)
if not parentID then
return nil
end
local parentName = self.cate .. '|' .. parentID
local global = allGlobals[parentName]
if not global then
return nil
end
local link = global.links[uri]
if not link then
return nil
end
local luckyBoy = link.sets[1] or link.gets[1]
if not luckyBoy then
return nil
end
return vm.getGlobalBase(luckyBoy)
end
---@param cate vm.global.cate
---@return vm.global
local function createGlobal(name, cate)
return setmetatable({
name = name,
cate = cate,
links = util.multiTable(2, function ()
return {
sets = {},
gets = {},
}
end),
}, mt)
end
---@class parser.object
---@field package _globalNode vm.global|false
---@field package _enums? parser.object[]
local compileObject
local compilerGlobalSwitch = util.switch()
: case 'local'
: call(function (source)
if source.special ~= '_G' then
return
end
if source.ref then
for _, ref in ipairs(source.ref) do
compileObject(ref)
end
end
end)
: case 'getlocal'
: call(function (source)
if source.special ~= '_G' then
return
end
if not source.next then
return
end
compileObject(source.next)
end)
: case 'setglobal'
: call(function (source)
local uri = guide.getUri(source)
local name = guide.getKeyName(source)
if not name then
return
end
local global = vm.declareGlobal('variable', name, uri)
global:addSet(uri, source)
source._globalNode = global
end)
: case 'getglobal'
: call(function (source)
local uri = guide.getUri(source)
local name = guide.getKeyName(source)
if not name then
return
end
local global = vm.declareGlobal('variable', name, uri)
global:addGet(uri, source)
source._globalNode = global
local nxt = source.next
if nxt then
compileObject(nxt)
end
end)
: case 'setfield'
: case 'setmethod'
: case 'setindex'
---@param source parser.object
: call(function (source)
local name
local keyName = guide.getKeyName(source)
if not keyName then
return
end
if source.node._globalNode then
local parentName = source.node._globalNode:getName()
if parentName == '_G' then
name = keyName
else
name = ('%s%s%s'):format(parentName, vm.ID_SPLITE, keyName)
end
elseif source.node.special == '_G' then
name = keyName
end
if not name then
return
end
local uri = guide.getUri(source)
local global = vm.declareGlobal('variable', name, uri)
global:addSet(uri, source)
source._globalNode = global
end)
: case 'getfield'
: case 'getmethod'
: case 'getindex'
---@param source parser.object
: call(function (source)
local name
local keyName = guide.getKeyName(source)
if not keyName then
return
end
if source.node._globalNode then
local parentName = source.node._globalNode:getName()
if parentName == '_G' then
name = keyName
else
name = ('%s%s%s'):format(parentName, vm.ID_SPLITE, keyName)
end
elseif source.node.special == '_G' then
name = keyName
end
local uri = guide.getUri(source)
local global = vm.declareGlobal('variable', name, uri)
global:addGet(uri, source)
source._globalNode = global
local nxt = source.next
if nxt then
compileObject(nxt)
end
end)
: case 'call'
: call(function (source)
if source.node.special == 'rawset'
or source.node.special == 'rawget' then
if not source.args then
return
end
local g = source.args[1]
local key = source.args[2]
if g and key and g.special == '_G' then
local name = guide.getKeyName(key)
if name then
local uri = guide.getUri(source)
local global = vm.declareGlobal('variable', name, uri)
if source.node.special == 'rawset' then
global:addSet(uri, source)
source.value = source.args[3]
else
global:addGet(uri, source)
end
source._globalNode = global
local nxt = source.next
if nxt then
compileObject(nxt)
end
end
end
end
end)
: case 'doc.class'
---@param source parser.object
: call(function (source)
local uri = guide.getUri(source)
local name = guide.getKeyName(source)
if not name then
return
end
local class = vm.declareGlobal('type', name, uri)
class:addSet(uri, source)
source._globalNode = class
if source.signs then
local sign = vm.createSign()
vm.setSign(source, sign)
for _, obj in ipairs(source.signs) do
sign:addSign(vm.compileNode(obj))
end
if source.extends then
for _, ext in ipairs(source.extends) do
if ext.type == 'doc.type.table' then
vm.setGeneric(ext, vm.createGeneric(ext, sign))
end
end
end
end
end)
: case 'doc.alias'
: call(function (source)
local uri = guide.getUri(source)
local name = guide.getKeyName(source)
if not name then
return
end
local alias = vm.declareGlobal('type', name, uri)
alias:addSet(uri, source)
source._globalNode = alias
if source.signs then
source._sign = vm.createSign()
for _, sign in ipairs(source.signs) do
source._sign:addSign(vm.compileNode(sign))
end
source.extends._generic = vm.createGeneric(source.extends, source._sign)
end
end)
: case 'doc.enum'
: call(function (source)
local uri = guide.getUri(source)
local name = guide.getKeyName(source)
if not name then
return
end
local enum = vm.declareGlobal('type', name, uri)
enum:addSet(uri, source)
source._globalNode = enum
local tbl = source.bindSource
if not tbl then
return
end
source._enums = {}
if vm.docHasAttr(source, 'key') then
for _, field in ipairs(tbl) do
if field.type == 'tablefield' then
source._enums[#source._enums+1] = {
type = 'doc.type.string',
start = field.field.start,
finish = field.field.finish,
[1] = field.field[1],
}
elseif field.type == 'tableindex' then
if field.index then
source._enums[#source._enums+1] = {
type = 'doc.type.string',
start = field.index.start,
finish = field.index.finish,
[1] = field.index[1],
}
end
end
end
else
for _, field in ipairs(tbl) do
if field.type == 'tablefield' then
source._enums[#source._enums+1] = field
local subType = vm.declareGlobal('type', name .. '.' .. field.field[1], uri)
subType:addSet(uri, field)
elseif field.type == 'tableindex' then
source._enums[#source._enums+1] = field
if field.index.type == 'string' then
local subType = vm.declareGlobal('type', name .. '.' .. field.index[1], uri)
subType:addSet(uri, field)
end
end
end
end
end)
: case 'doc.type.name'
: call(function (source)
local uri = guide.getUri(source)
local name = source[1]
if name == '_' then
return
end
if name == 'self' then
return
end
local type = vm.declareGlobal('type', name, uri)
type:addGet(uri, source)
source._globalNode = type
end)
: case 'doc.extends.name'
: call(function (source)
local uri = guide.getUri(source)
local name = source[1]
local class = vm.declareGlobal('type', name, uri)
class:addGet(uri, source)
source._globalNode = class
end)
---@alias vm.global.cate '"variable"' | '"type"'
---@param cate vm.global.cate
---@param name string
---@param uri? uri
---@return vm.global
function vm.declareGlobal(cate, name, uri)
local key = cate .. '|' .. name
if uri then
globalSubs[uri][key] = true
end
if not allGlobals[key] then
allGlobals[key] = createGlobal(name, cate)
end
return allGlobals[key]
end
---@param cate vm.global.cate
---@param name string
---@param field? string
---@return vm.global?
function vm.getGlobal(cate, name, field)
local key = cate .. '|' .. name
if field then
key = key .. vm.ID_SPLITE .. field
end
return allGlobals[key]
end
---@param cate vm.global.cate
---@param name string
---@return vm.global[]
function vm.getGlobalFields(cate, name)
local globals = {}
local key = cate .. '|' .. name
local clock = os.clock()
for gid, global in pairs(allGlobals) do
if gid ~= key
and util.stringStartWith(gid, key)
and gid:sub(#key + 1, #key + 1) == vm.ID_SPLITE
and not gid:find(vm.ID_SPLITE, #key + 2) then
globals[#globals+1] = global
end
end
local cost = os.clock() - clock
if cost > 0.1 then
log.warn('global-manager getFields costs', cost)
end
return globals
end
---@param cate vm.global.cate
---@return vm.global[]
function vm.getGlobals(cate)
local globals = {}
local clock = os.clock()
for gid, global in pairs(allGlobals) do
if util.stringStartWith(gid, cate)
and not gid:find(vm.ID_SPLITE) then
globals[#globals+1] = global
end
end
local cost = os.clock() - clock
if cost > 0.1 then
log.warn('global-manager getGlobals costs', cost)
end
return globals
end
---@return table<string, vm.global>
function vm.getAllGlobals()
return allGlobals
end
---@param suri uri
---@param cate vm.global.cate
---@return parser.object[]
function vm.getGlobalSets(suri, cate)
local globals = vm.getGlobals(cate)
local result = {}
for _, global in ipairs(globals) do
local sets = global:getSets(suri)
for _, set in ipairs(sets) do
result[#result+1] = set
end
end
return result
end
---@param suri uri
---@param cate vm.global.cate
---@param name string
---@return boolean
function vm.hasGlobalSets(suri, cate, name)
local global = vm.getGlobal(cate, name)
if not global then
return false
end
local sets = global:getSets(suri)
if #sets == 0 then
return false
end
return true
end
---@param uri uri
---@param key string
---@return boolean
local function checkIsGlobalRegex(uri, key)
local dglobalsregex = config.get(uri, 'Lua.diagnostics.globalsRegex')
if not dglobalsregex then
return false
end
for _, pattern in ipairs(dglobalsregex) do
if key:match(pattern) then
return true
end
end
return false
end
---@param src parser.object
local function checkIsUndefinedGlobal(src)
if src.type ~= 'getglobal' then
return false
end
local key = src[1]
if not key then
return false
end
local node = src.node
if node.tag ~= '_ENV' then
return false
end
local uri = guide.getUri(src)
local rspecial = config.get(uri, 'Lua.runtime.special')
if rspecial[key] then
return false
end
if vm.hasGlobalSets(uri, 'variable', key) then
return false
end
local dglobals = config.get(uri, 'Lua.diagnostics.globals')
if util.arrayHas(dglobals, key) then
return false
end
if checkIsGlobalRegex(uri, key) then
return false
end
return true
end
---@param src parser.object
---@return boolean
function vm.isUndefinedGlobal(src)
local node = vm.compileNode(src)
if node.undefinedGlobal == nil then
node.undefinedGlobal = checkIsUndefinedGlobal(src)
end
return node.undefinedGlobal
end
---@param source parser.object
function compileObject(source)
if source._globalNode ~= nil then
return
end
source._globalNode = false
compilerGlobalSwitch(source.type, source)
end
---@param source parser.object
---@return vm.global?
function vm.getGlobalNode(source)
return source._globalNode or nil
end
---@param source parser.object
---@return parser.object[]?
function vm.getEnums(source)
return source._enums
end
---@param source parser.object
---@return boolean
function vm.compileByGlobal(source)
local global = vm.getGlobalNode(source)
if not global then
return false
end
vm.setNode(source, global)
if global.cate == 'variable' then
if guide.isAssign(source) then
if vm.bindDocs(source) then
return true
end
if source.value and source.value.type ~= 'nil' then
vm.setNode(source, vm.compileNode(source.value))
return true
end
else
if vm.bindAs(source) then
return true
end
local node = vm.traceNode(source)
if node then
vm.setNode(source, node, true)
return true
end
end
end
local globalBase = vm.getGlobalBase(source)
if not globalBase then
return false
end
local globalNode = vm.compileNode(globalBase)
vm.setNode(source, globalNode, true)
return true
end
---@param source parser.object
---@return parser.object?
function vm.getGlobalBase(source)
if source._globalBase then
return source._globalBase
end
local global = vm.getGlobalNode(source)
if not global then
return nil
end
---@cast source parser.object
local root = guide.getRoot(source)
if not root._globalBaseMap then
root._globalBaseMap = {}
end
local name = global:asKeyName()
if not root._globalBaseMap[name] then
---@diagnostic disable-next-line: missing-fields
root._globalBaseMap[name] = {
type = 'globalbase',
parent = root,
global = global,
start = 0,
finish = 0,
}
end
source._globalBase = root._globalBaseMap[name]
return source._globalBase
end
---@param source parser.object
local function compileAst(source)
local env = guide.getENV(source)
if not env then
return
end
compileObject(env)
guide.eachSpecialOf(source, 'rawset', function (src)
compileObject(src.parent)
end)
guide.eachSpecialOf(source, 'rawget', function (src)
compileObject(src.parent)
end)
guide.eachSourceTypes(source.docs, {
'doc.class',
'doc.alias',
'doc.type.name',
'doc.extends.name',
'doc.enum',
}, function (src)
compileObject(src)
end)
end
---@param uri uri
local function dropUri(uri)
local globalSub = globalSubs[uri]
globalSubs[uri] = nil
for key in pairs(globalSub) do
local global = allGlobals[key]
if global then
global:dropUri(uri)
if not global:isAlive() then
allGlobals[key] = nil
end
end
end
end
return {
compileAst = compileAst,
dropUri = dropUri,
}

View File

@@ -0,0 +1,615 @@
local util = require 'utility'
local config = require 'config'
local guide = require 'parser.guide'
---@class vm
local vm = require 'vm.vm'
---@class vm.infer
---@field node vm.node
---@field views table<string, boolean>
---@field _drop table
---@field _lastView? string
---@field _lastViewUri? uri
---@field _lastViewDefault? any
---@field _subViews? string[]
local mt = {}
mt.__index = mt
mt._hasTable = false
mt._hasClass = false
mt._hasFunctionDef = false
mt._hasDocFunction = false
mt._isParam = false
mt._isLocal = false
vm.NULL = setmetatable({}, mt)
local LOCK = {}
local inferSorted = {
['boolean'] = - 100,
['string'] = - 99,
['number'] = - 98,
['integer'] = - 97,
['function'] = - 96,
['table'] = - 95,
['true'] = 1,
['false'] = 2,
['nil'] = 100,
}
local viewNodeSwitch;viewNodeSwitch = util.switch()
: case 'nil'
: case 'boolean'
: case 'string'
: case 'integer'
: call(function (source, _infer)
return source.type
end)
: case 'number'
: call(function (source, _infer)
return source.type
end)
: case 'table'
: call(function (source, infer, uri)
local docs = source.bindDocs
if docs then
for _, doc in ipairs(docs) do
if doc.type == 'doc.enum' then
return 'enum ' .. doc.enum[1]
end
end
end
if #source == 1 and source[1].type == 'varargs' then
local node = vm.getInfer(source[1]):view(uri)
return ('%s[]'):format(node)
end
infer._hasTable = true
end)
: case 'function'
: call(function (source, infer)
local parent = source.parent
if guide.isAssign(parent) then
infer._hasFunctionDef = true
end
return source.type
end)
: case 'local'
: call(function (source, infer)
if source.parent == 'funcargs' then
infer._isParam = true
else
infer._isLocal = true
end
end)
: case 'global'
: call(function (source, infer)
if source.cate == 'type' then
if not guide.isBasicType(source.name) then
infer._hasClass = true
end
return source.name
end
end)
: case 'doc.type'
: call(function (source, infer, uri)
local buf = {}
for _, tp in ipairs(source.types) do
buf[#buf+1] = viewNodeSwitch(tp.type, tp, infer, uri)
end
return table.concat(buf, '|')
end)
: case 'doc.type.name'
: call(function (source, _infer, uri)
if source.signs then
local buf = {}
for i, sign in ipairs(source.signs) do
buf[i] = vm.getInfer(sign):view(uri)
end
return ('%s<%s>'):format(source[1], table.concat(buf, ', '))
else
return source[1]
end
end)
: case 'generic'
: call(function (source, _infer, uri)
return vm.getInfer(source.proto):view(uri)
end)
: case 'doc.generic.name'
: call(function (source, _infer, uri)
local resolved = vm.getGenericResolved(source)
if resolved then
return vm.getInfer(resolved):view(uri)
end
if source.generic and source.generic.extends then
return ('<%s:%s>'):format(source[1], vm.getInfer(source.generic.extends):view(uri))
else
return ('<%s>'):format(source[1])
end
end)
: case 'doc.type.array'
: call(function (source, infer, uri)
infer._hasClass = true
local view = vm.getInfer(source.node):view(uri)
if source.node.type == 'doc.type' then
view = '(' .. view .. ')'
end
return view .. '[]'
end)
: case 'doc.type.sign'
: call(function (source, infer, uri)
infer._hasClass = true
local buf = {}
for i, sign in ipairs(source.signs) do
buf[i] = vm.getInfer(sign):view(uri)
end
local node = vm.compileNode(source)
for c in node:eachObject() do
if guide.isLiteral(c) then
---@cast c parser.object
local view = vm.getInfer(c):view(uri)
if view then
infer._drop[view] = true
end
end
end
return ('%s<%s>'):format(source.node[1], table.concat(buf, ', '))
end)
: case 'doc.type.table'
: call(function (source, infer, uri)
if #source.fields == 0 then
infer._hasTable = true
return
end
infer._hasClass = true
local buf = {}
buf[#buf+1] = source.isTuple and '[' or '{ '
for i, field in ipairs(source.fields) do
if i > 1 then
buf[#buf+1] = ', '
end
if not source.isTuple then
local key = field.name
if key.type == 'doc.type' then
buf[#buf+1] = ('[%s]: '):format(vm.getInfer(key):view(uri))
elseif type(key[1]) == 'string' then
buf[#buf+1] = key[1] .. ': '
else
buf[#buf+1] = ('[%q]: '):format(key[1])
end
end
buf[#buf+1] = vm.getInfer(field.extends):view(uri)
end
buf[#buf+1] = source.isTuple and ']' or ' }'
return table.concat(buf)
end)
: case 'doc.type.string'
: call(function (source, _infer)
return util.viewString(source[1], source[2])
end)
: case 'doc.type.integer'
: case 'doc.type.boolean'
: call(function (source, _infer)
return ('%q'):format(source[1])
end)
: case 'doc.type.code'
: call(function (source, _infer)
return ('`%s`'):format(source[1])
end)
: case 'doc.type.function'
: call(function (source, infer, uri)
infer._hasDocFunction = true
local args = {}
local rets = {}
local argView = ''
local regView = ''
for i, arg in ipairs(source.args) do
local argNode = vm.compileNode(arg)
local isOptional = argNode:isOptional()
if isOptional then
argNode = argNode:copy()
argNode:removeOptional()
end
args[i] = string.format('%s%s%s%s'
, arg.name[1]
, isOptional and '?' or ''
, arg.name[1] == '...' and '' or ': '
, vm.getInfer(argNode):view(uri)
)
end
if #args > 0 then
argView = table.concat(args, ', ')
end
local needReturnParen
for i, ret in ipairs(source.returns) do
local retType = vm.getInfer(ret):view(uri)
if ret.name then
if ret.name[1] == '...' then
rets[i] = ('%s%s'):format(ret.name[1], retType)
else
needReturnParen = true
rets[i] = ('%s: %s'):format(ret.name[1], retType)
end
else
rets[i] = retType
end
end
if #rets > 0 then
if needReturnParen then
regView = (':(%s)'):format(table.concat(rets, ', '))
else
regView = (':%s'):format(table.concat(rets, ', '))
end
end
return ('fun(%s)%s'):format(argView, regView)
end)
: case 'doc.field.name'
: call(function (source, _infer, uri)
return vm.viewKey(source, uri)
end)
---@param node? vm.node
---@return vm.infer
local function createInfer(node)
local infer = setmetatable({
node = node,
_drop = {},
}, mt)
return infer
end
---@param source vm.node.object | vm.node
---@return vm.infer
function vm.getInfer(source)
---@type vm.node
local node
if source.type == 'vm.node' then
---@cast source vm.node
node = source
else
---@cast source vm.object
node = vm.compileNode(source)
end
if node.lastInfer then
return node.lastInfer
end
local infer = createInfer(node)
node.lastInfer = infer
return infer
end
function mt:_trim()
if self._hasDocFunction then
if self._hasFunctionDef then
for view in pairs(self.views) do
if view:sub(1, 4) == 'fun(' then
self.views[view] = nil
end
end
else
self.views['function'] = nil
end
end
if self._hasTable and not self._hasClass then
self.views['table'] = true
end
if self.views['number'] then
self.views['integer'] = nil
end
if self.views['boolean'] then
self.views['true'] = nil
self.views['false'] = nil
end
end
---@param uri uri
function mt:_eraseAlias(uri)
local count = 0
for _ in pairs(self.views) do
count = count + 1
end
if count <= 1 then
return
end
local expandAlias = config.get(uri, 'Lua.hover.expandAlias')
for n in self.node:eachObject() do
if n.type == 'global' and n.cate == 'type' then
if LOCK[n.name] then
goto CONTINUE
end
LOCK[n.name] = true
for _, set in ipairs(n:getSets(uri)) do
if set.type == 'doc.alias' then
if expandAlias then
self._drop[n.name] = true
local newInfer = createInfer()
for _, ext in ipairs(set.extends.types) do
viewNodeSwitch(ext.type, ext, newInfer, uri)
end
if newInfer._hasTable then
self.views['table'] = true
end
else
for _, ext in ipairs(set.extends.types) do
local view = viewNodeSwitch(ext.type, ext, createInfer(), uri)
if view and view ~= n.name then
self._drop[view] = true
end
end
end
end
end
LOCK[n.name] = nil
::CONTINUE::
end
end
end
---@param uri uri
---@param tp string
---@return boolean
function mt:hasType(uri, tp)
self:_computeViews(uri)
return self.views[tp] == true
end
---@param uri uri
function mt:hasUnknown(uri)
self:_computeViews(uri)
return not next(self.views)
or self.views['unknown'] == true
end
---@param uri uri
function mt:hasAny(uri)
self:_computeViews(uri)
return self.views['any'] == true
end
---@param uri uri
---@return boolean
function mt:hasClass(uri)
self:_computeViews(uri)
return self._hasClass == true
end
---@param uri uri
---@return boolean
function mt:hasFunction(uri)
self:_computeViews(uri)
return self.views['function'] == true
or self._hasDocFunction == true
end
---@param uri uri
function mt:_computeViews(uri)
if self.views then
return
end
self.views = {}
for n in self.node:eachObject() do
if not n.hideView then
local view = viewNodeSwitch(n.type, n, self, uri)
if view then
self.views[view] = true
end
end
end
self:_trim()
end
---@param uri uri
---@param default? string
---@return string
function mt:view(uri, default)
if self._lastView
and self._lastViewUri == uri
and self._lastViewDefault == default then
return self._lastView
end
self._lastViewUri = uri
self._lastViewDefault = default
self:_computeViews(uri)
if self.views['any'] then
self._lastView = 'any'
return 'any'
end
if self._hasClass then
self:_eraseAlias(uri)
end
local array = {}
self._subViews = array
for view in pairs(self.views) do
if not self._drop[view] then
array[#array+1] = view
end
end
table.sort(array, function (a, b)
local sa = inferSorted[a] or 0
local sb = inferSorted[b] or 0
if sa == sb then
return a < b
end
return sa < sb
end)
local max = #array
local limit = config.get(uri, 'Lua.hover.enumsLimit')
local view
if #array == 0 then
view = default or 'unknown'
else
if max > limit then
view = string.format('%s...(+%d)'
, table.concat(array, '|', 1, limit)
, max - limit
)
else
view = table.concat(array, '|')
end
end
if self.node:isOptional() then
if #array == 0 then
view = 'nil'
else
if max > 1
or view:find(guide.notNamePattern .. guide.namePattern .. '$') then
view = '(' .. view .. ')?'
else
view = view .. '?'
end
end
end
-- do not truncate if exporting doc
if not DOC and #view > 200 then
view = view:sub(1, 180) .. '...(too long)...' .. view:sub(-10)
end
self._lastView = view
return view
end
---@param uri uri
function mt:eachView(uri)
self:_computeViews(uri)
return next, self.views
end
---@param uri uri
---@return string[]
function mt:getSubViews(uri)
self:view(uri)
return self._subViews
end
---@return string?
function mt:viewLiterals()
if not self.node then
return nil
end
local mark = {}
local literals = {}
for n in self.node:eachObject() do
if n.type == 'string'
or n.type == 'number'
or n.type == 'integer'
or n.type == 'boolean' then
local literal
if n.type == 'string' then
literal = util.viewString(n[1], n[2])
else
literal = util.viewLiteral(n[1])
end
if literal and not mark[literal] then
literals[#literals+1] = literal
mark[literal] = true
end
end
end
if #literals == 0 then
return nil
end
table.sort(literals, function (a, b)
local sa = inferSorted[a] or 0
local sb = inferSorted[b] or 0
if sa == sb then
return a < b
end
return sa < sb
end)
return table.concat(literals, '|')
end
---@return string?
function mt:viewClass()
if not self.node then
return nil
end
local mark = {}
local class = {}
for n in self.node:eachObject() do
if n.type == 'global' and n.cate == 'type' then
local name = n.name
if not mark[name] then
class[#class+1] = name
mark[name] = true
end
end
end
if #class == 0 then
return nil
end
table.sort(class)
return table.concat(class, '|')
end
---@param source vm.node.object
---@param uri uri
---@return string?
function vm.viewObject(source, uri)
local infer = createInfer()
return viewNodeSwitch(source.type, source, infer, uri)
end
---@param source parser.object
---@param uri uri
---@return string?
---@return string|number|boolean|nil
function vm.viewKey(source, uri)
if source.type == 'doc.type' then
if #source.types == 1 then
return vm.viewKey(source.types[1], uri)
else
local key = vm.getInfer(source):view(uri)
return '[' .. key .. ']', key
end
end
if source.type == 'tableindex'
or source.type == 'setindex'
or source.type == 'getindex' then
local index = source.index
local name = vm.getInfer(index):viewLiterals()
if not name then
return nil
end
return ('[%s]'):format(name), name
end
if source.type == 'tableexp' then
return ('[%d]'):format(source.tindex), source.tindex
end
if source.type == 'doc.field' then
return vm.viewKey(source.field, uri)
end
if source.type == 'doc.type.field' then
return vm.viewKey(source.name, uri)
end
if source.type == 'doc.type.name' then
return '[' .. source[1] .. ']', source[1]
end
if source.type == 'doc.type.string' then
local name = util.viewString(source[1], source[2])
return ('[%s]'):format(name), name
end
local key = vm.getKeyName(source)
if key == nil then
return nil
end
if type(key) == 'string' then
return key, key
else
return ('[%s]'):format(key), key
end
end

View File

@@ -0,0 +1,25 @@
local vm = require 'vm.vm'
---@alias vm.object parser.object | vm.generic
require 'vm.compiler'
require 'vm.value'
require 'vm.node'
require 'vm.def'
require 'vm.ref'
require 'vm.field'
require 'vm.doc'
require 'vm.type'
require 'vm.library'
require 'vm.tracer'
require 'vm.infer'
require 'vm.generic'
require 'vm.sign'
require 'vm.variable'
require 'vm.global'
require 'vm.function'
require 'vm.operator'
require 'vm.visible'
require 'vm.precompile'
return vm

View File

@@ -0,0 +1,15 @@
---@class vm
local vm = require 'vm.vm'
function vm.getLibraryName(source)
if source.special then
return source.special
end
local defs = vm.getDefs(source)
for _, def in ipairs(defs) do
if def.special then
return def.special
end
end
return nil
end

View File

@@ -0,0 +1,537 @@
local files = require 'files'
---@class vm
local vm = require 'vm.vm'
local ws = require 'workspace.workspace'
local guide = require 'parser.guide'
local timer = require 'timer'
local util = require 'utility'
---@type table<vm.object, vm.node>
vm.nodeCache = setmetatable({}, util.MODE_K)
---@alias vm.node.object vm.object | vm.global | vm.variable
---@class vm.node
---@field [integer] vm.node.object
---@field [vm.node.object] true
---@field fields? table<vm.node|string, vm.node>
---@field undefinedGlobal boolean?
---@field lastInfer? vm.infer
local mt = {}
mt.__index = mt
mt.id = 0
mt.type = 'vm.node'
mt.optional = nil
mt.data = nil
mt.hasDefined = nil
mt.originNode = nil
---@param node vm.node | vm.node.object
---@return vm.node
function mt:merge(node)
if not node then
return self
end
self.lastInfer = nil
if node.type == 'vm.node' then
if node == self then
return self
end
if node:isOptional() then
self.optional = true
end
for _, obj in ipairs(node) do
if not self[obj] then
self[obj] = true
self[#self+1] = obj
end
end
else
---@cast node -vm.node
if not self[node] then
self[node] = true
self[#self+1] = node
end
end
return self
end
---@return boolean
function mt:isEmpty()
return #self == 0
end
---@return boolean
function mt:isTyped()
for _, c in ipairs(self) do
if c.type == 'global' and c.cate == 'type' then
return true
end
if guide.isLiteral(c) then
return true
end
end
return false
end
function mt:clear()
self.optional = nil
for i, c in ipairs(self) do
self[i] = nil
self[c] = nil
end
end
---@param n integer
---@return vm.node.object?
function mt:get(n)
return self[n]
end
function mt:addOptional()
self.optional = true
end
function mt:removeOptional()
self:remove 'nil'
return self
end
---@return boolean
function mt:isOptional()
return self.optional == true
end
---@return boolean
function mt:hasFalsy()
if self.optional then
return true
end
for _, c in ipairs(self) do
if c.type == 'nil'
or (c.type == 'global' and c.cate == 'type' and c.name == 'nil')
or (c.type == 'global' and c.cate == 'type' and c.name == 'false')
or (c.type == 'boolean' and c[1] == false)
or (c.type == 'doc.type.boolean' and c[1] == false) then
return true
end
end
return false
end
---Almost an inverse of hasFalsy, but stricter about "any" and "unknown" types.
---@return boolean
function mt:alwaysTruthy()
if self.optional then
return false
end
if #self == 0 then
return false
end
for _, c in ipairs(self) do
if c.type == 'nil'
or (c.type == 'global' and c.cate == 'type' and c.name == 'nil')
or (c.type == 'global' and c.cate == 'type' and c.name == 'false')
or (c.type == 'global' and c.cate == 'type' and c.name == 'any')
or (c.type == 'global' and c.cate == 'type' and c.name == 'boolean')
or (c.type == 'global' and c.cate == 'type' and c.name == 'doc.type.boolean')
or (c.type == 'global' and c.cate == 'type' and c.name == 'unknown')
or not self:hasKnownType()
or (c.type == 'boolean' and c[1] == false)
or (c.type == 'doc.type.boolean' and c[1] == false) then
return false
end
end
return true
end
---@return boolean
function mt:hasKnownType()
for _, c in ipairs(self) do
if c.type == 'global' and c.cate == 'type' then
return true
end
if guide.isLiteral(c) then
return true
end
end
return false
end
---@return boolean
function mt:isNullable()
if self.optional then
return true
end
if #self == 0 then
return true
end
for _, c in ipairs(self) do
if c.type == 'nil'
or (c.type == 'global' and c.cate == 'type' and c.name == 'nil')
or (c.type == 'global' and c.cate == 'type' and c.name == 'any')
or (c.type == 'global' and c.cate == 'type' and c.name == '...') then
return true
end
end
return false
end
---@return vm.node
function mt:setTruthy()
if self.optional == true then
self.optional = nil
end
local hasBoolean
for index = #self, 1, -1 do
local c = self[index]
if c.type == 'nil'
or (c.type == 'global' and c.cate == 'type' and c.name == 'nil')
or (c.type == 'global' and c.cate == 'type' and c.name == 'false')
or (c.type == 'boolean' and c[1] == false)
or (c.type == 'doc.type.boolean' and c[1] == false) then
table.remove(self, index)
self[c] = nil
goto CONTINUE
end
if c.type == 'global' and c.cate == 'type' and c.name == 'boolean' then
hasBoolean = true
table.remove(self, index)
self[c] = nil
goto CONTINUE
end
if c.type == 'boolean' or c.type == 'doc.type.boolean' then
if c[1] == false then
table.remove(self, index)
self[c] = nil
goto CONTINUE
end
end
::CONTINUE::
end
if hasBoolean then
self:merge(vm.declareGlobal('type', 'true'))
end
return self
end
---@return vm.node
function mt:setFalsy()
if self.optional == false then
self.optional = nil
end
local hasBoolean
for index = #self, 1, -1 do
local c = self[index]
if c.type == 'nil'
or (c.type == 'global' and c.cate == 'type' and c.name == 'nil')
or (c.type == 'global' and c.cate == 'type' and c.name == 'false')
or (c.type == 'boolean' and c[1] == false)
or (c.type == 'doc.type.boolean' and c[1] == false) then
goto CONTINUE
end
if c.type == 'global' and c.cate == 'type' and c.name == 'boolean' then
hasBoolean = true
table.remove(self, index)
self[c] = nil
goto CONTINUE
end
if c.type == 'boolean' or c.type == 'doc.type.boolean' then
if c[1] == true then
table.remove(self, index)
self[c] = nil
goto CONTINUE
end
end
if (c.type == 'global' and c.cate == 'type') then
table.remove(self, index)
self[c] = nil
goto CONTINUE
end
if guide.isLiteral(c) then
table.remove(self, index)
self[c] = nil
goto CONTINUE
end
::CONTINUE::
end
if hasBoolean then
self:merge(vm.declareGlobal('type', 'false'))
end
return self
end
---@param name string
function mt:remove(name)
if name == 'nil' and self.optional == true then
self.optional = nil
end
for index = #self, 1, -1 do
local c = self[index]
if (c.type == 'global' and c.cate == 'type' and c.name == name)
or (c.type == name)
or (c.type == 'doc.type.integer' and (name == 'number' or name == 'integer'))
or (c.type == 'doc.type.boolean' and name == 'boolean')
or (c.type == 'doc.type.boolean' and name == 'true' and c[1] == true)
or (c.type == 'doc.type.boolean' and name == 'false' and c[1] == false)
or (c.type == 'doc.type.table' and name == 'table')
or (c.type == 'doc.type.array' and name == 'table')
or (c.type == 'doc.type.sign' and name == c.node[1])
or (c.type == 'doc.type.function' and name == 'function')
or (c.type == 'doc.type.string' and name == 'string') then
table.remove(self, index)
self[c] = nil
end
end
return self
end
---@param uri uri
---@param name string
function mt:narrow(uri, name)
if self.optional == true then
self.optional = nil
end
for index = #self, 1, -1 do
local c = self[index]
if (c.type == name)
or (c.type == 'doc.type.integer' and (name == 'number' or name == 'integer'))
or (c.type == 'doc.type.boolean' and name == 'boolean')
or (c.type == 'doc.type.table' and name == 'table')
or (c.type == 'doc.type.array' and name == 'table')
or (c.type == 'doc.type.sign' and name == c.node[1])
or (c.type == 'doc.type.function' and name == 'function')
or (c.type == 'doc.type.string' and name == 'string') then
goto CONTINUE
end
if c.type == 'global' and c.cate == 'type' then
if (c.name == name)
or (vm.isSubType(uri, c.name, name)) then
goto CONTINUE
end
end
table.remove(self, index)
self[c] = nil
::CONTINUE::
end
if #self == 0 then
self[#self+1] = vm.getGlobal('type', name)
end
return self
end
---@param obj vm.object | vm.variable
function mt:removeObject(obj)
for index, c in ipairs(self) do
if c == obj then
table.remove(self, index)
self[c] = nil
return
end
end
end
---@param node vm.node
function mt:removeNode(node)
for _, c in ipairs(node) do
if c.type == 'global' and c.cate == 'type' then
---@cast c vm.global
self:remove(c.name)
elseif c.type == 'nil' then
self:remove 'nil'
elseif c.type == 'boolean'
or c.type == 'doc.type.boolean' then
if c[1] == true then
self:remove 'true'
else
self:remove 'false'
end
else
---@cast c -vm.global
self:removeObject(c)
end
end
end
---@param name string
---@return boolean
function mt:hasType(name)
for _, c in ipairs(self) do
if c.type == 'global' and c.cate == 'type' and c.name == name then
return true
end
end
return false
end
---@param name string
---@return boolean
function mt:hasName(name)
if name == 'nil' and self.optional == true then
return true
end
for _, c in ipairs(self) do
if c.type == 'global' and c.cate == 'type' and c.name == name then
return true
end
if c.type == name then
return true
end
-- TODO
end
return false
end
---@return vm.node
function mt:asTable()
self.optional = nil
for index = #self, 1, -1 do
local c = self[index]
if c.type == 'table'
or c.type == 'doc.type.table'
or c.type == 'doc.type.array' then
goto CONTINUE
end
if c.type == 'doc.type.sign' then
if c.node[1] == 'table'
or not guide.isBasicType(c.node[1]) then
goto CONTINUE
end
end
if c.type == 'global' and c.cate == 'type' then
---@cast c vm.global
if c.name == 'table'
or not guide.isBasicType(c.name) then
goto CONTINUE
end
end
table.remove(self, index)
self[c] = nil
::CONTINUE::
end
return self
end
---@return fun():vm.node.object
function mt:eachObject()
local i = 0
return function ()
i = i + 1
return self[i]
end
end
---@return vm.node
function mt:copy()
return vm.createNode(self)
end
---@param source vm.node.object | vm.generic
---@param node vm.node | vm.node.object
---@param cover? boolean
---@return vm.node
function vm.setNode(source, node, cover)
if not node then
if TEST then
error('Can not set nil node')
else
log.error('Can not set nil node')
end
end
if cover then
---@cast node vm.node
vm.nodeCache[source] = node
return node
end
local me = vm.nodeCache[source]
if me then
me:merge(node)
else
if node.type == 'vm.node' then
me = node:copy()
else
me = vm.createNode(node)
end
vm.nodeCache[source] = me
end
return me
end
---@param source vm.node.object
---@return vm.node?
function vm.getNode(source)
return vm.nodeCache[source]
end
---@param source vm.object
function vm.removeNode(source)
vm.nodeCache[source] = nil
end
local lockCount = 0
local needClearCache = false
function vm.lockCache()
lockCount = lockCount + 1
end
function vm.unlockCache()
lockCount = lockCount - 1
if needClearCache then
needClearCache = false
vm.clearNodeCache()
end
end
function vm.clearNodeCache()
if lockCount > 0 then
needClearCache = true
return
end
log.debug('clearNodeCache')
vm.nodeCache = {}
end
local ID = 0
---@param a? vm.node | vm.node.object
---@param b? vm.node | vm.node.object
---@return vm.node
function vm.createNode(a, b)
ID = ID + 1
local node = setmetatable({
id = ID,
}, mt)
if a then
node:merge(a)
end
if b then
node:merge(b)
end
return node
end
---@type timer?
local delayTimer
files.watch(function (ev, uri)
if ev == 'version' then
if ws.isReady(uri) then
if CACHEALIVE then
if delayTimer then
delayTimer:restart()
end
delayTimer = timer.wait(1, function ()
delayTimer = nil
vm.clearNodeCache()
end)
else
vm.clearNodeCache()
end
end
end
end)
ws.watch(function (ev, _uri)
if ev == 'reload' then
vm.clearNodeCache()
end
end)

View File

@@ -0,0 +1,437 @@
---@class vm
local vm = require 'vm.vm'
local util = require 'utility'
local guide = require 'parser.guide'
local config = require 'config'
vm.UNARY_OP = {
'unm',
'bnot',
'len',
}
vm.BINARY_OP = {
'add',
'sub',
'mul',
'div',
'mod',
'pow',
'idiv',
'band',
'bor',
'bxor',
'shl',
'shr',
'concat',
}
vm.OTHER_OP = {
'call',
}
local unaryMap = {
['-'] = 'unm',
['~'] = 'bnot',
['#'] = 'len',
}
local binaryMap = {
['+'] = 'add',
['-'] = 'sub',
['*'] = 'mul',
['/'] = 'div',
['%'] = 'mod',
['^'] = 'pow',
['//'] = 'idiv',
['&'] = 'band',
['|'] = 'bor',
['~'] = 'bxor',
['<<'] = 'shl',
['>>'] = 'shr',
['..'] = 'concat',
}
local otherMap = {
['()'] = 'call',
}
vm.OP_UNARY_MAP = util.revertMap(unaryMap)
vm.OP_BINARY_MAP = util.revertMap(binaryMap)
vm.OP_OTHER_MAP = util.revertMap(otherMap)
---@param operators parser.object[]
---@param op string
---@param value? parser.object
---@param result? vm.node
---@return vm.node?
local function checkOperators(operators, op, value, result)
for _, operator in ipairs(operators) do
if operator.op[1] ~= op
or not operator.extends then
goto CONTINUE
end
if value and operator.exp then
local valueNode = vm.compileNode(value)
local expNode = vm.compileNode(operator.exp)
local uri = guide.getUri(operator)
for vo in valueNode:eachObject() do
if vm.isSubType(uri, vo, expNode) then
if not result then
result = vm.createNode()
end
result:merge(vm.compileNode(operator.extends))
return result
end
end
else
if not result then
result = vm.createNode()
end
result:merge(vm.compileNode(operator.extends))
return result
end
::CONTINUE::
end
return result
end
---@param op string
---@param exp parser.object
---@param value? parser.object
---@return vm.node?
function vm.runOperator(op, exp, value)
local uri = guide.getUri(exp)
local node = vm.compileNode(exp)
local result
for c in node:eachObject() do
if c.type == 'string'
or c.type == 'doc.type.string' then
c = vm.declareGlobal('type', 'string')
end
if c.type == 'global' and c.cate == 'type' then
---@cast c vm.global
for _, set in ipairs(c:getSets(uri)) do
if set.operators and #set.operators > 0 then
result = checkOperators(set.operators, op, value, result)
end
end
end
end
return result
end
vm.unarySwich = util.switch()
: case 'not'
: call(function (source)
local result = vm.testCondition(source[1])
if result == nil then
vm.setNode(source, vm.declareGlobal('type', 'boolean'))
else
---@diagnostic disable-next-line: missing-fields
vm.setNode(source, {
type = 'boolean',
start = source.start,
finish = source.finish,
parent = source,
[1] = not result,
})
end
end)
: case '#'
: call(function (source)
local node = vm.runOperator('len', source[1])
vm.setNode(source, node or vm.declareGlobal('type', 'integer'))
end)
: case '-'
: call(function (source)
local v = vm.getNumber(source[1])
if v == nil then
local uri = guide.getUri(source)
local infer = vm.getInfer(source[1])
if infer:hasType(uri, 'integer') then
vm.setNode(source, vm.declareGlobal('type', 'integer'))
elseif infer:hasType(uri, 'number') then
vm.setNode(source, vm.declareGlobal('type', 'number'))
else
local node = vm.runOperator('unm', source[1])
vm.setNode(source, node or vm.declareGlobal('type', 'number'))
end
else
---@diagnostic disable-next-line: missing-fields
vm.setNode(source, {
type = 'number',
start = source.start,
finish = source.finish,
parent = source,
[1] = -v,
})
end
end)
: case '~'
: call(function (source)
local v = vm.getInteger(source[1])
if v == nil then
local node = vm.runOperator('bnot', source[1])
vm.setNode(source, node or vm.declareGlobal('type', 'integer'))
else
---@diagnostic disable-next-line: missing-fields
vm.setNode(source, {
type = 'integer',
start = source.start,
finish = source.finish,
parent = source,
[1] = ~v,
})
end
end)
vm.binarySwitch = util.switch()
: case 'and'
: call(function (source)
local node1 = vm.compileNode(source[1])
local node2 = vm.compileNode(source[2])
local r1 = vm.testCondition(source[1])
if r1 == true then
vm.setNode(source, node2)
elseif r1 == false then
vm.setNode(source, node1)
else
local node = node1:copy():setFalsy():merge(node2)
vm.setNode(source, node)
end
end)
: case 'or'
: call(function (source)
local node1 = vm.compileNode(source[1])
local node2 = vm.compileNode(source[2])
local r1 = vm.testCondition(source[1])
if r1 == true then
vm.setNode(source, node1)
elseif r1 == false then
vm.setNode(source, node2)
else
local node = node1:copy():setTruthy()
if not source[2].hasExit then
node:merge(node2)
end
vm.setNode(source, node)
end
end)
: case '=='
: case '~='
: call(function (source)
local result = vm.equal(source[1], source[2])
if result == nil then
vm.setNode(source, vm.declareGlobal('type', 'boolean'))
else
if source.op.type == '~=' then
result = not result
end
---@diagnostic disable-next-line: missing-fields
vm.setNode(source, {
type = 'boolean',
start = source.start,
finish = source.finish,
parent = source,
[1] = result,
})
end
end)
: case '<<'
: case '>>'
: case '&'
: case '|'
: case '~'
: call(function (source)
local a = vm.getInteger(source[1])
local b = vm.getInteger(source[2])
local op = source.op.type
if a and b then
local result = op == '<<' and a << b
or op == '>>' and a >> b
or op == '&' and a & b
or op == '|' and a | b
or op == '~' and a ~ b
---@diagnostic disable-next-line: missing-fields
vm.setNode(source, {
type = 'integer',
start = source.start,
finish = source.finish,
parent = source,
[1] = result,
})
else
local node = vm.runOperator(binaryMap[op], source[1], source[2])
if not node then
node = vm.runOperator(binaryMap[op], source[2], source[1])
end
if node then
vm.setNode(source, node)
end
end
end)
: case '+'
: case '-'
: case '*'
: case '/'
: case '%'
: case '//'
: case '^'
: call(function (source)
local a = vm.getNumber(source[1])
local b = vm.getNumber(source[2])
local op = source.op.type
local zero = b == 0
and ( op == '%'
or op == '/'
or op == '//'
)
if a and b and not zero then
local result = op == '+' and a + b
or op == '-' and a - b
or op == '*' and a * b
or op == '/' and a / b
or op == '%' and a % b
or op == '//' and a // b
or op == '^' and a ^ b
---@diagnostic disable-next-line: missing-fields
vm.setNode(source, {
type = (op == '//' or math.type(result) == 'integer') and 'integer' or 'number',
start = source.start,
finish = source.finish,
parent = source,
[1] = result,
})
else
local node = vm.runOperator(binaryMap[op], source[1], source[2])
if not node then
node = vm.runOperator(binaryMap[op], source[2], source[1])
end
if node then
vm.setNode(source, node)
return
end
if op == '+'
or op == '-'
or op == '*'
or op == '%' then
local uri = guide.getUri(source)
local infer1 = vm.getInfer(source[1])
local infer2 = vm.getInfer(source[2])
if infer1:hasType(uri, 'integer')
and infer2:hasType(uri, 'integer') then
vm.setNode(source, vm.declareGlobal('type', 'integer'))
return
end
if (infer1:hasType(uri, 'number') or infer1:hasType(uri, 'integer'))
and (infer2:hasType(uri, 'number') or infer2:hasType(uri, 'integer')) then
vm.setNode(source, vm.declareGlobal('type', 'number'))
return
end
end
if op == '/'
or op == '^' then
local uri = guide.getUri(source)
local infer1 = vm.getInfer(source[1])
local infer2 = vm.getInfer(source[2])
if (infer1:hasType(uri, 'integer') or infer1:hasType(uri, 'number'))
and (infer2:hasType(uri, 'integer') or infer2:hasType(uri, 'number')) then
vm.setNode(source, vm.declareGlobal('type', 'number'))
return
end
end
if op == '//' then
local uri = guide.getUri(source)
local infer1 = vm.getInfer(source[1])
local infer2 = vm.getInfer(source[2])
if (infer1:hasType(uri, 'integer') or infer1:hasType(uri, 'number'))
and (infer2:hasType(uri, 'integer') or infer2:hasType(uri, 'number')) then
vm.setNode(source, vm.declareGlobal('type', 'integer'))
return
end
end
end
end)
: case '..'
: call(function (source)
local a = vm.getString(source[1])
or vm.getNumber(source[1])
local b = vm.getString(source[2])
or vm.getNumber(source[2])
if a and b then
if type(a) == 'number' or type(b) == 'number' then
local uri = guide.getUri(source)
local version = config.get(uri, 'Lua.runtime.version')
if math.tointeger(a) and math.type(a) == 'float' then
if version == 'Lua 5.3' or version == 'Lua 5.4' then
a = ('%.1f'):format(a)
else
a = ('%.0f'):format(a)
end
end
if math.tointeger(b) and math.type(b) == 'float' then
if version == 'Lua 5.3' or version == 'Lua 5.4' then
b = ('%.1f'):format(b)
else
b = ('%.0f'):format(b)
end
end
end
---@diagnostic disable-next-line: missing-fields
vm.setNode(source, {
type = 'string',
start = source.start,
finish = source.finish,
parent = source,
[1] = a .. b,
})
else
local uri = guide.getUri(source)
local infer1 = vm.getInfer(source[1])
local infer2 = vm.getInfer(source[2])
if (
infer1:hasType(uri, 'integer')
or infer1:hasType(uri, 'number')
or infer1:hasType(uri, 'string')
)
and (
infer2:hasType(uri, 'integer')
or infer2:hasType(uri, 'number')
or infer2:hasType(uri, 'string')
) then
vm.setNode(source, vm.declareGlobal('type', 'string'))
return
end
local node = vm.runOperator(binaryMap[source.op.type], source[1], source[2])
if not node then
node = vm.runOperator(binaryMap[source.op.type], source[2], source[1])
end
if node then
vm.setNode(source, node)
end
end
end)
: case '>'
: case '<'
: case '>='
: case '<='
: call(function (source)
local a = vm.getNumber(source[1])
local b = vm.getNumber(source[2])
if a and b then
local op = source.op.type
local result = op == '>' and a > b
or op == '<' and a < b
or op == '>=' and a >= b
or op == '<=' and a <= b
---@diagnostic disable-next-line: missing-fields
vm.setNode(source, {
type = 'boolean',
start = source.start,
finish = source.finish,
parent = source,
[1] =result,
})
else
vm.setNode(source, vm.declareGlobal('type', 'boolean'))
end
end)

View File

@@ -0,0 +1,20 @@
local files = require 'files'
local global = require 'vm.global'
local variable = require 'vm.variable'
---@async
files.watch(function (ev, uri)
if ev == 'update' then
global.dropUri(uri)
end
if ev == 'remove' then
global.dropUri(uri)
end
if ev == 'compile' then
local state = files.getLastState(uri)
if state then
global.compileAst(state.ast)
variable.compileAst(state.ast)
end
end
end)

View File

@@ -0,0 +1,328 @@
---@class vm
local vm = require 'vm.vm'
local util = require 'utility'
local guide = require 'parser.guide'
local files = require 'files'
local await = require 'await'
local progress = require 'progress'
local lang = require 'language'
local simpleSwitch
simpleSwitch = util.switch()
: case 'goto'
: call(function (source, pushResult)
if source.node then
simpleSwitch('label', source.node, pushResult)
pushResult(source.node)
end
end)
: case 'label'
: call(function (source, pushResult)
pushResult(source)
if source.ref then
for _, ref in ipairs(source.ref) do
pushResult(ref)
end
end
end)
---@async
local function searchInAllFiles(suri, searcher, notify)
await.delay()
searcher(suri)
await.delay()
local uris = {}
for uri in files.eachFile(suri) do
if not vm.isMetaFile(uri)
and suri ~= uri then
uris[#uris+1] = uri
end
end
local loading <close> = progress.create(suri, lang.script.WINDOW_SEARCHING_IN_FILES, 1)
local cancelled
loading:onCancel(function ()
cancelled = true
end)
for i, uri in ipairs(uris) do
if notify then
local continue = notify(uri)
if continue == false then
break
end
end
loading:setMessage(('%03d/%03d'):format(i, #uris))
loading:setPercentage(i / #uris * 100)
await.delay()
if cancelled then
break
end
searcher(uri)
end
end
---@async
local function searchWord(source, pushResult, defMap, fileNotify)
local key = guide.getKeyName(source)
if not key then
return
end
local global = vm.getGlobalNode(source)
---@param src parser.object
local function checkDef(src)
for _, def in ipairs(vm.getDefs(src)) do
if defMap[def] then
pushResult(src)
return
end
end
end
---@async
local function findWord(uri)
local text = files.getText(uri)
if not text then
return
end
if not text:find(key, 1, true) then
return
end
local state = files.getState(uri)
if not state then
return
end
if global then
local globalName = global:asKeyName()
---@async
guide.eachSourceTypes(state.ast, {'getglobal', 'setglobal', 'setfield', 'getfield', 'setmethod', 'getmethod', 'setindex', 'getindex', 'doc.type.name', 'doc.class.name', 'doc.alias.name', 'doc.extends.name'}, function (src)
local myGlobal = vm.getGlobalNode(src)
if myGlobal and myGlobal:asKeyName() == globalName then
pushResult(src)
await.delay()
end
end)
end
---@async
guide.eachSourceTypes(state.ast, {'getfield', 'setfield', 'tablefield'}, function (src)
if src.field and src.field[1] == key then
checkDef(src)
await.delay()
end
end)
---@async
guide.eachSourceTypes(state.ast, {'getmethod', 'setmethod'}, function (src)
if src.method and src.method[1] == key then
checkDef(src)
await.delay()
end
end)
---@async
guide.eachSourceTypes(state.ast, {'getindex', 'setindex'}, function (src)
if src.index and src.index.type == 'string' and src.index[1] == key then
checkDef(src)
await.delay()
end
end)
end
searchInAllFiles(guide.getUri(source), findWord, fileNotify)
end
---@async
local function searchFunction(source, pushResult, defMap, fileNotify)
---@param src parser.object
local function checkDef(src)
for _, def in ipairs(vm.getDefs(src)) do
if defMap[def] then
pushResult(src)
return
end
end
end
---@async
local function findCall(uri)
local state = files.getState(uri)
if not state then
return
end
---@async
guide.eachSourceType(state.ast, 'call', function (src)
checkDef(src.node)
await.delay()
end)
end
searchInAllFiles(guide.getUri(source), findCall, fileNotify)
end
local searchByParentNode
local nodeSwitch = util.switch()
: case 'field'
: case 'method'
---@async
: call(function (source, pushResult, defMap, fileNotify)
searchByParentNode(source.parent, pushResult, defMap, fileNotify)
end)
: case 'getfield'
: case 'setfield'
: case 'getmethod'
: case 'setmethod'
: case 'getindex'
: case 'setindex'
---@async
: call(function (source, pushResult, defMap, fileNotify)
local key = guide.getKeyName(source)
if type(key) ~= 'string' then
return
end
searchWord(source, pushResult, defMap, fileNotify)
end)
: case 'tablefield'
: case 'tableindex'
: case 'doc.field.name'
---@async
: call(function (source, pushResult, defMap, fileNotify)
searchWord(source, pushResult, defMap, fileNotify)
end)
: case 'setglobal'
: case 'getglobal'
---@async
: call(function (source, pushResult, defMap, fileNotify)
searchWord(source, pushResult, defMap, fileNotify)
end)
: case 'doc.alias.name'
: case 'doc.class.name'
: case 'doc.enum.name'
---@async
: call(function (source, pushResult, defMap, fileNotify)
searchWord(source.parent, pushResult, defMap, fileNotify)
end)
: case 'doc.alias'
: case 'doc.class'
: case 'doc.enum'
: case 'doc.type.name'
: case 'doc.extends.name'
---@async
: call(function (source, pushResult, defMap, fileNotify)
searchWord(source, pushResult, defMap, fileNotify)
end)
: case 'function'
: case 'doc.type.function'
---@async
: call(function (source, pushResult, defMap, fileNotify)
searchFunction(source, pushResult, defMap, fileNotify)
end)
---@param source parser.object
---@param pushResult fun(src: parser.object)
local function searchBySimple(source, pushResult)
simpleSwitch(source.type, source, pushResult)
end
---@param source parser.object
---@param pushResult fun(src: parser.object)
local function searchByLocalID(source, pushResult)
local sourceSets = vm.getVariableSets(source)
if sourceSets then
for _, src in ipairs(sourceSets) do
pushResult(src)
end
end
local sourceGets = vm.getVariableGets(source)
if sourceGets then
for _, src in ipairs(sourceGets) do
pushResult(src)
end
end
end
---@async
---@param source parser.object
---@param pushResult fun(src: parser.object)
---@param fileNotify? fun(uri: uri): boolean
function searchByParentNode(source, pushResult, defMap, fileNotify)
nodeSwitch(source.type, source, pushResult, defMap, fileNotify)
end
local function searchByGlobal(source, pushResult)
if source.type == 'field'
or source.type == 'method'
or source.type == 'doc.class.name'
or source.type == 'doc.alias.name' then
source = source.parent
end
local node = vm.getGlobalNode(source)
if not node then
return
end
local uri = guide.getUri(source)
for _, set in ipairs(node:getSets(uri)) do
pushResult(set)
end
end
local function searchByDef(source, pushResult)
local defMap = {}
if source.type == 'function'
or source.type == 'doc.type.function' then
defMap[source] = true
return defMap
end
if source.type == 'field'
or source.type == 'method' then
source = source.parent
end
if source.type == 'doc.field.name' then
source = source.parent
end
defMap[source] = true
local defs = vm.getDefs(source)
for _, def in ipairs(defs) do
pushResult(def)
if not guide.isLiteral(def)
and def.type ~= 'doc.alias'
and def.type ~= 'doc.class'
and def.type ~= 'doc.enum' then
defMap[def] = true
end
end
return defMap
end
---@async
---@param source parser.object
---@param fileNotify? fun(uri: uri): boolean
function vm.getRefs(source, fileNotify)
local results = {}
local mark = {}
local hasLocal
local function pushResult(src)
if src.type == 'local' then
if hasLocal then
return
end
hasLocal = true
end
if not mark[src] then
mark[src] = true
results[#results+1] = src
end
end
searchBySimple(source, pushResult)
searchByLocalID(source, pushResult)
searchByGlobal(source, pushResult)
local defMap = searchByDef(source, pushResult)
searchByParentNode(source, pushResult, defMap, fileNotify)
return results
end

View File

@@ -0,0 +1,343 @@
local guide = require 'parser.guide'
---@class vm
local vm = require 'vm.vm'
---@class vm.sign
---@field parent parser.object
---@field signList vm.node[]
---@field docGeneric parser.object[]
local mt = {}
mt.__index = mt
mt.type = 'sign'
---@param node vm.node
function mt:addSign(node)
self.signList[#self.signList+1] = node
end
---@param doc parser.object
function mt:addDocGeneric(doc)
self.docGeneric[#self.docGeneric+1] = doc
end
---@param uri uri
---@param args parser.object
---@return table<string, vm.node>?
function mt:resolve(uri, args)
if not args then
return nil
end
---@type table<string, vm.node>
local resolved = {}
---@param object vm.node|vm.node.object
---@param node vm.node
local function resolve(object, node)
if object.type == 'vm.node' then
for o in object:eachObject() do
resolve(o, node)
end
return
end
if object.type == 'doc.type' then
---@cast object parser.object
resolve(vm.compileNode(object), node)
return
end
if object.type == 'doc.generic.name' then
---@type string
local key = object[1]
if object.literal then
-- 'number' -> `T`
for n in node:eachObject() do
if n.type == 'string' then
---@cast n parser.object
local type = vm.declareGlobal('type', object.pattern and object.pattern:format(n[1]) or n[1], guide.getUri(n))
resolved[key] = vm.createNode(type, resolved[key])
end
end
else
-- number -> T
for n in node:eachObject() do
if n.type ~= 'doc.generic.name'
and n.type ~= 'generic' then
if resolved[key] then
resolved[key]:merge(n)
else
resolved[key] = vm.createNode(n)
end
end
end
if resolved[key] and node:isOptional() then
resolved[key]:addOptional()
end
end
return
end
if object.type == 'doc.type.array' then
for n in node:eachObject() do
if n.type == 'doc.type.array' then
-- number[] -> T[]
resolve(object.node, vm.compileNode(n.node))
end
if n.type == 'doc.type.table' then
-- { [integer]: number } -> T[]
local tvalueNode = vm.getTableValue(uri, node, 'integer', true)
if tvalueNode then
resolve(object.node, tvalueNode)
end
end
if n.type == 'global' and n.cate == 'type' then
-- ---@field [integer]: number -> T[]
---@cast n vm.global
vm.getClassFields(uri, n, vm.declareGlobal('type', 'integer'), function (field)
resolve(object.node, vm.compileNode(field.extends))
end)
end
if n.type == 'table' and #n >= 1 then
-- { x } / { ... } -> T[]
resolve(object.node, vm.compileNode(n[1]))
end
end
return
end
if object.type == 'doc.type.table' then
for _, ufield in ipairs(object.fields) do
local ufieldNode = vm.compileNode(ufield.name)
local uvalueNode = vm.compileNode(ufield.extends)
local firstField = ufieldNode:get(1)
local firstValue = uvalueNode:get(1)
if not firstField or not firstValue then
goto CONTINUE
end
if firstField.type == 'doc.generic.name' and firstValue.type == 'doc.generic.name' then
-- { [number]: number} -> { [K]: V }
local tfieldNode = vm.getTableKey(uri, node, 'any', true)
local tvalueNode = vm.getTableValue(uri, node, 'any', true)
if tfieldNode then
resolve(firstField, tfieldNode)
end
if tvalueNode then
resolve(firstValue, tvalueNode)
end
else
if ufieldNode:get(1).type == 'doc.generic.name' then
-- { [number]: number}|number[] -> { [K]: number }
local tnode = vm.getTableKey(uri, node, uvalueNode, true)
if tnode then
resolve(firstField, tnode)
end
elseif uvalueNode:get(1).type == 'doc.generic.name' then
-- { [number]: number}|number[] -> { [number]: V }
local tnode = vm.getTableValue(uri, node, ufieldNode, true)
if tnode then
resolve(firstValue, tnode)
end
end
end
::CONTINUE::
end
return
end
if object.type == 'doc.type.function' then
for i, arg in ipairs(object.args) do
if arg.extends then
for n in node:eachObject() do
if n.type == 'function'
or n.type == 'doc.type.function' then
---@cast n parser.object
local farg = n.args and n.args[i]
if farg then
resolve(arg.extends, vm.compileNode(farg))
end
end
end
end
end
for i, ret in ipairs(object.returns) do
for n in node:eachObject() do
if n.type == 'function'
or n.type == 'doc.type.function' then
---@cast n parser.object
local fret = vm.getReturnOfFunction(n, i)
if fret then
resolve(ret, vm.compileNode(fret))
end
end
end
end
return
end
end
---@param sign vm.node
---@return table<string, true>
---@return table<string, true>
local function getSignInfo(sign)
local knownTypes = {}
local genericsNames = {}
for obj in sign:eachObject() do
if obj.type == 'doc.generic.name' then
genericsNames[obj[1]] = true
goto CONTINUE
end
if obj.type == 'doc.type.table'
or obj.type == 'doc.type.function'
or obj.type == 'doc.type.array' then
---@cast obj parser.object
local hasGeneric
guide.eachSourceType(obj, 'doc.generic.name', function (src)
hasGeneric = true
genericsNames[src[1]] = true
end)
if hasGeneric then
goto CONTINUE
end
end
if obj.type == 'variable'
or obj.type == 'local' then
goto CONTINUE
end
local view = vm.getInfer(obj):view(uri)
if view then
knownTypes[view] = true
end
::CONTINUE::
end
return knownTypes, genericsNames
end
-- remove un-generic type
---@param argNode vm.node
---@param sign vm.node
---@param knownTypes table<string, true>
---@return vm.node
local function buildArgNode(argNode, sign, knownTypes)
local newArgNode = vm.createNode()
local needRemoveNil = sign:hasFalsy()
for n in argNode:eachObject() do
if needRemoveNil then
if n.type == 'nil' then
goto CONTINUE
end
if n.type == 'global' and n.cate == 'type' and n.name == 'nil' then
goto CONTINUE
end
end
local view = vm.getInfer(n):view(uri)
if knownTypes[view] then
goto CONTINUE
end
newArgNode:merge(n)
::CONTINUE::
end
if not needRemoveNil and argNode:isOptional() then
newArgNode:addOptional()
end
return newArgNode
end
---@param genericNames table<string, true>
local function isAllResolved(genericNames)
for n in pairs(genericNames) do
if not resolved[n] then
return false
end
end
return true
end
for i, arg in ipairs(args) do
local sign = self.signList[i]
if not sign then
break
end
local argNode = vm.compileNode(arg)
local knownTypes, genericNames = getSignInfo(sign)
if not isAllResolved(genericNames) then
local newArgNode = buildArgNode(argNode, sign, knownTypes)
resolve(sign, newArgNode)
end
end
return resolved
end
---@return vm.sign
function vm.createSign()
local genericMgr = setmetatable({
signList = {},
docGeneric = {},
}, mt)
return genericMgr
end
---@class parser.object
---@field package _sign vm.sign|false|nil
---@param source parser.object
---@param sign vm.sign
function vm.setSign(source, sign)
source._sign = sign
end
---@param source parser.object
---@return vm.sign?
function vm.getSign(source)
if source._sign ~= nil then
return source._sign or nil
end
source._sign = false
if source.type == 'function' then
if not source.bindDocs then
return nil
end
for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.generic' then
if not source._sign then
source._sign = vm.createSign()
end
source._sign:addDocGeneric(doc)
end
end
if not source._sign then
return nil
end
if source.args then
for _, arg in ipairs(source.args) do
local argNode = vm.compileNode(arg)
if arg.optional then
argNode:addOptional()
end
source._sign:addSign(argNode)
end
end
end
if source.type == 'doc.type.function'
or source.type == 'doc.type.table'
or source.type == 'doc.type.array' then
local hasGeneric
guide.eachSourceType(source, 'doc.generic.name', function (_)
hasGeneric = true
end)
if not hasGeneric then
return nil
end
source._sign = vm.createSign()
if source.type == 'doc.type.function' then
for _, arg in ipairs(source.args) do
if arg.extends then
local argNode = vm.compileNode(arg.extends)
if arg.optional then
argNode:addOptional()
end
source._sign:addSign(argNode)
else
source._sign:addSign(vm.createNode())
end
end
end
end
return source._sign or nil
end

View File

@@ -0,0 +1,947 @@
---@class vm
local vm = require 'vm.vm'
local guide = require 'parser.guide'
local util = require 'utility'
---@class parser.object
---@field package _tracer? vm.tracer
---@field package _casts? parser.object[]
---@alias tracer.mode 'local' | 'global'
---@class vm.tracer
---@field mode tracer.mode
---@field name string
---@field source parser.object | vm.variable
---@field assigns (parser.object | vm.variable)[]
---@field assignMap table<parser.object, true>
---@field getMap table<parser.object, true>
---@field careMap table<parser.object, true>
---@field mark table<parser.object, true>
---@field casts parser.object[]
---@field nodes table<parser.object, vm.node|false>
---@field main parser.object
---@field uri uri
---@field castIndex integer?
local mt = {}
mt.__index = mt
mt.fastCalc = true
---@return parser.object[]
function mt:getCasts()
local root = guide.getRoot(self.main)
if not root._casts then
root._casts = {}
local docs = root.docs
for _, doc in ipairs(docs) do
if doc.type == 'doc.cast' and doc.name then
root._casts[#root._casts+1] = doc
end
end
end
return root._casts
end
---@param obj parser.object
function mt:collectAssign(obj)
while true do
local block = guide.getParentBlock(obj)
if not block then
return
end
obj = block
if self.assignMap[obj] then
return
end
if obj == self.main then
return
end
self.assignMap[obj] = true
self.assigns[#self.assigns+1] = obj
end
end
---@param obj parser.object
function mt:collectCare(obj)
while true do
if self.careMap[obj] then
return
end
if obj == self.main then
return
end
if not obj then
return
end
self.careMap[obj] = true
if self.fastCalc then
if obj.type == 'if'
or obj.type == 'while'
or obj.type == 'binary' then
self.fastCalc = false
end
if obj.type == 'call' and obj.node then
if obj.node.special == 'assert'
or obj.node.special == 'type' then
self.fastCalc = false
end
end
end
obj = obj.parent
end
end
function mt:collectLocal()
local startPos = self.source.base.start
local finishPos = 0
local variable = self.source
if variable.base.type ~= 'local'
and variable.base.type ~= 'self' then
self.assigns[#self.assigns+1] = variable
self.assignMap[self.source] = true
end
for _, set in ipairs(variable.sets) do
self.assigns[#self.assigns+1] = set
self.assignMap[set] = true
self:collectCare(set)
if set.finish > finishPos then
finishPos = set.finish
end
end
for _, get in ipairs(variable.gets) do
self:collectCare(get)
self.getMap[get] = true
if get.finish > finishPos then
finishPos = get.finish
end
end
local casts = self:getCasts()
for _, cast in ipairs(casts) do
if cast.name[1] == self.name
and cast.start > startPos
and cast.finish < finishPos
and vm.getCastTargetHead(cast) == variable.base then
self.casts[#self.casts+1] = cast
end
end
if #self.casts > 0 then
self.fastCalc = false
end
end
function mt:collectGlobal()
self.assigns[#self.assigns+1] = self.source
self.assignMap[self.source] = true
local uri = guide.getUri(self.source)
local global = self.source.global
local link = global.links[uri]
for _, set in ipairs(link.sets) do
self.assigns[#self.assigns+1] = set
self.assignMap[set] = true
self:collectCare(set)
end
for _, get in ipairs(link.gets) do
self:collectCare(get)
self.getMap[get] = true
end
local casts = self:getCasts()
for _, cast in ipairs(casts) do
if cast.name[1] == self.name then
local castTarget = vm.getCastTargetHead(cast)
if castTarget and castTarget.type == 'global' then
self.casts[#self.casts+1] = cast
end
end
end
if #self.casts > 0 then
self.fastCalc = false
end
end
---@param start integer
---@param finish integer
---@return parser.object?
function mt:getLastAssign(start, finish)
local lastAssign
for _, assign in ipairs(self.assigns) do
local obj
if assign.type == 'variable' then
---@cast assign vm.variable
obj = assign.base
else
---@cast assign parser.object
obj = assign
end
if obj.start < start then
goto CONTINUE
end
if (obj.effect or obj.range or obj.start) >= finish then
break
end
local objBlock = guide.getTopBlock(obj)
if not objBlock then
break
end
if objBlock.start <= finish
and objBlock.finish >= finish then
lastAssign = obj
end
::CONTINUE::
end
return lastAssign
end
---@param pos integer
function mt:resetCastsIndex(pos)
for i = 1, #self.casts do
local cast = self.casts[i]
if cast.start > pos then
self.castIndex = i
return
end
end
self.castIndex = nil
end
---@param pos integer
---@param node vm.node
---@return vm.node
function mt:fastWardCasts(pos, node)
if not self.castIndex then
return node
end
for i = self.castIndex, #self.casts do
local action = self.casts[i]
if action.start > pos then
return node
end
node = node:copy()
for _, cast in ipairs(action.casts) do
if cast.mode == '+' then
if cast.optional then
node:addOptional()
end
if cast.extends then
node:merge(vm.compileNode(cast.extends))
end
elseif cast.mode == '-' then
if cast.optional then
node:removeOptional()
end
if cast.extends then
node:removeNode(vm.compileNode(cast.extends))
end
else
if cast.extends then
node:clear()
node:merge(vm.compileNode(cast.extends))
end
end
end
end
self.castIndex = self.castIndex + 1
return node
end
--- Return types of source which have a field with the value of literal.
--- @param uri uri
--- @param source parser.object
--- @param fieldName string
--- @param literal parser.object
--- @return [string, boolean][]?
local function getNodeTypesWithLiteralField(uri, source, fieldName, literal)
local loc = vm.getVariable(source)
if not loc then
return
end
-- Literal must has a value
if literal[1] == nil then
return
end
local tys
for _, c in ipairs(vm.compileNode(loc)) do
if c.cate == 'type' then
for _, set in ipairs(c:getSets(uri)) do
if set.type == 'doc.class' then
for _, f in ipairs(set.fields) do
if f.field[1] == fieldName then
for _, t in ipairs(f.extends.types) do
if guide.isLiteral(t) and t[1] ~= nil and t[1] == literal[1] then
tys = tys or {}
table.insert(tys, { set.class[1], #f.extends.types > 1 })
break
end
end
break
end
end
end
end
end
end
return tys
end
local lookIntoChild = util.switch()
: case 'getlocal'
: case 'getglobal'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
if tracer.getMap[action] then
tracer.nodes[action] = topNode
if outNode then
topNode = topNode:copy():setTruthy()
outNode = outNode:copy():setFalsy()
end
end
return topNode, outNode
end)
: case 'repeat'
: case 'loop'
: case 'for'
: case 'do'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
if action.type == 'loop' then
tracer:lookIntoChild(action.init, topNode)
tracer:lookIntoChild(action.max, topNode)
end
if action[1] then
tracer:lookIntoBlock(action, action.bstart, topNode:copy())
local lastAssign = tracer:getLastAssign(action.start, action.finish)
if lastAssign then
tracer:getNode(lastAssign)
end
if tracer.nodes[action] then
topNode = tracer.nodes[action]:copy()
end
end
if action.type == 'repeat' then
tracer:lookIntoChild(action.filter, topNode)
end
return topNode, outNode
end)
: case 'in'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
tracer:lookIntoChild(action.exps, topNode)
if action[1] then
tracer:lookIntoBlock(action, action.bstart, topNode:copy())
local lastAssign = tracer:getLastAssign(action.start, action.finish)
if lastAssign then
tracer:getNode(lastAssign)
end
if tracer.nodes[action] then
topNode = tracer.nodes[action]:copy()
end
end
return topNode, outNode
end)
: case 'while'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
local blockNode, mainNode
if action.filter then
blockNode, mainNode = tracer:lookIntoChild(action.filter, topNode:copy(), topNode:copy())
else
blockNode = topNode:copy()
mainNode = topNode:copy()
end
if action[1] then
tracer:lookIntoBlock(action, action.bstart, blockNode:copy())
local lastAssign = tracer:getLastAssign(action.start, action.finish)
if lastAssign then
tracer:getNode(lastAssign)
end
if tracer.nodes[action] then
topNode = mainNode:merge(tracer.nodes[action])
end
end
if action.filter then
-- look into filter again
guide.eachSource(action.filter, function (src)
tracer.mark[src] = nil
end)
blockNode, topNode = tracer:lookIntoChild(action.filter, topNode:copy(), topNode:copy())
end
return topNode, outNode
end)
: case 'if'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
local hasElse
local mainNode = topNode:copy()
local blockNodes = {}
for _, subBlock in ipairs(action) do
tracer:resetCastsIndex(subBlock.start)
local blockNode = mainNode:copy()
if subBlock.filter then
blockNode, mainNode = tracer:lookIntoChild(subBlock.filter, blockNode, mainNode)
else
hasElse = true
mainNode:clear()
end
local mergedNode
if subBlock[1] then
tracer:lookIntoBlock(subBlock, subBlock.bstart, blockNode:copy())
local neverReturn = subBlock.hasReturn
or subBlock.hasGoTo
or subBlock.hasBreak
or subBlock.hasExit
if neverReturn then
mergedNode = true
else
local lastAssign = tracer:getLastAssign(subBlock.start, subBlock.finish)
if lastAssign then
tracer:getNode(lastAssign)
end
if tracer.nodes[subBlock] then
blockNodes[#blockNodes+1] = tracer.nodes[subBlock]
mergedNode = true
end
end
end
if not mergedNode then
blockNodes[#blockNodes+1] = blockNode
end
end
if not hasElse and not topNode:hasKnownType() then
mainNode:merge(vm.declareGlobal('type', 'unknown'))
end
for _, blockNode in ipairs(blockNodes) do
mainNode:merge(blockNode)
end
topNode = mainNode
return topNode, outNode
end)
: case 'getfield'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
tracer:lookIntoChild(action.node, topNode)
tracer:lookIntoChild(action.field, topNode)
if tracer.getMap[action] then
tracer.nodes[action] = topNode
if outNode then
topNode = topNode:copy():setTruthy()
outNode = outNode:copy():setFalsy()
end
end
return topNode, outNode
end)
: case 'getmethod'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
tracer:lookIntoChild(action.node, topNode)
tracer:lookIntoChild(action.method, topNode)
if tracer.getMap[action] then
tracer.nodes[action] = topNode
if outNode then
topNode = topNode:copy():setTruthy()
outNode = outNode:copy():setFalsy()
end
end
return topNode, outNode
end)
: case 'getindex'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
tracer:lookIntoChild(action.node, topNode)
tracer:lookIntoChild(action.index, topNode)
if tracer.getMap[action] then
tracer.nodes[action] = topNode
if outNode then
topNode = topNode:copy():setTruthy()
outNode = outNode:copy():setFalsy()
end
end
return topNode, outNode
end)
: case 'setfield'
: case 'setmethod'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
tracer:lookIntoChild(action.node, topNode)
tracer:lookIntoChild(action.value, topNode)
return topNode, outNode
end)
: case 'setglobal'
: case 'setlocal'
: case 'tablefield'
: case 'tableexp'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
tracer:lookIntoChild(action.value, topNode)
return topNode, outNode
end)
: case 'setindex'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
tracer:lookIntoChild(action.node, topNode)
tracer:lookIntoChild(action.index, topNode)
tracer:lookIntoChild(action.value, topNode)
return topNode, outNode
end)
: case 'tableindex'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
tracer:lookIntoChild(action.index, topNode)
tracer:lookIntoChild(action.value, topNode)
return topNode, outNode
end)
: case 'local'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
tracer:lookIntoChild(action.value, topNode)
-- special treat for `local tp = type(x)`
if action.value
and action.ref
and action.value.type == 'select' then
local index = action.value.sindex
local call = action.value.vararg
if index == 1
and call.type == 'call'
and call.node
and call.node.special == 'type'
and call.args then
local getVar = call.args[1]
if getVar
and tracer.getMap[getVar] then
for _, ref in ipairs(action.ref) do
tracer:collectCare(ref)
end
end
end
end
return topNode, outNode
end)
: case 'return'
: case 'table'
: case 'callargs'
: case 'list'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
for _, ret in ipairs(action) do
tracer:lookIntoChild(ret, topNode:copy())
end
return topNode, outNode
end)
: case 'select'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
tracer:lookIntoChild(action.vararg, topNode)
return topNode, outNode
end)
: case 'function'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
tracer:lookIntoBlock(action, action.bstart, topNode:copy())
return topNode, outNode
end)
: case 'paren'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
topNode, outNode = tracer:lookIntoChild(action.exp, topNode, outNode)
return topNode, outNode
end)
: case 'call'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
if action.node.special == 'assert' and action.args and action.args[1] then
for i = 2, #action.args do
tracer:lookIntoChild(action.args[i], topNode, topNode:copy())
end
topNode = tracer:lookIntoChild(action.args[1], topNode:copy(), topNode:copy())
end
tracer:lookIntoChild(action.node, topNode)
tracer:lookIntoChild(action.args, topNode)
return topNode, outNode
end)
: case 'binary'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
if not action[1] or not action[2] then
tracer:lookIntoChild(action[1], topNode)
tracer:lookIntoChild(action[2], topNode)
return topNode, outNode
end
if action.op.type == 'and' then
topNode = tracer:lookIntoChild(action[1], topNode, topNode:copy())
topNode = tracer:lookIntoChild(action[2], topNode, topNode:copy())
elseif action.op.type == 'or' then
outNode = outNode or topNode:copy()
local topNode1, outNode1 = tracer:lookIntoChild(action[1], topNode, outNode)
local topNode2, outNode2 = tracer:lookIntoChild(action[2], outNode1, outNode1:copy())
topNode = vm.createNode(topNode1, topNode2)
outNode = outNode2:copy()
elseif action.op.type == '=='
or action.op.type == '~=' then
local handler, checker
for i = 1, 2 do
if guide.isLiteral(action[i]) then
checker = action[i]
handler = action[3-i] -- Copilot tells me use `3-i` instead of `i%2+1`
end
end
if not handler then
tracer:lookIntoChild(action[1], topNode)
tracer:lookIntoChild(action[2], topNode)
return topNode, outNode
end
if tracer.getMap[handler] then
-- if x == y then
topNode = tracer:lookIntoChild(handler, topNode, outNode)
local checkerNode = vm.compileNode(checker)
local checkerName = vm.getNodeName(checker)
if checkerName then
topNode = topNode:copy()
if action.op.type == '==' then
topNode:narrow(tracer.uri, checkerName)
if outNode then
outNode:removeNode(checkerNode)
end
else
topNode:removeNode(checkerNode)
if outNode then
outNode:narrow(tracer.uri, checkerName)
end
end
end
elseif handler.type == 'getfield'
and handler.node.type == 'getlocal' then
local tys
if handler.field then
tys = getNodeTypesWithLiteralField(tracer.uri, handler.node, handler.field[1], checker)
end
-- TODO: handle more types
if tys and #tys == 1 then
-- If the type is in a union (e.g. 'lit' | foo), then the type
-- cannot be removed from the node.
local ty, tyInUnion = tys[1][1], tys[1][2]
topNode = topNode:copy()
if action.op.type == '==' then
topNode:narrow(tracer.uri, ty)
if not tyInUnion and outNode then
outNode:remove(ty)
end
else
if not tyInUnion then
topNode:remove(ty)
end
if outNode then
outNode:narrow(tracer.uri, ty)
end
end
end
elseif handler.type == 'call'
and checker.type == 'string'
and handler.node.special == 'type'
and handler.args
and handler.args[1]
and tracer.getMap[handler.args[1]] then
-- if type(x) == 'string' then
tracer:lookIntoChild(handler, topNode)
topNode = topNode:copy()
if action.op.type == '==' then
topNode:narrow(tracer.uri, checker[1])
if outNode then
outNode:remove(checker[1])
end
else
topNode:remove(checker[1])
if outNode then
outNode:narrow(tracer.uri, checker[1])
end
end
elseif handler.type == 'getlocal'
and checker.type == 'string' then
-- `local tp = type(x);if tp == 'string' then`
local nodeValue = vm.getObjectValue(handler.node)
if nodeValue
and nodeValue.type == 'select'
and nodeValue.sindex == 1 then
local call = nodeValue.vararg
if call
and call.type == 'call'
and call.node.special == 'type'
and call.args
and tracer.getMap[call.args[1]] then
if action.op.type == '==' then
topNode:narrow(tracer.uri, checker[1])
if outNode then
outNode:remove(checker[1])
end
else
topNode:remove(checker[1])
if outNode then
outNode:narrow(tracer.uri, checker[1])
end
end
end
end
end
end
tracer:lookIntoChild(action[1], topNode)
tracer:lookIntoChild(action[2], topNode)
return topNode, outNode
end)
: case 'unary'
---@param tracer vm.tracer
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
: call(function (tracer, action, topNode, outNode)
if not action[1] then
tracer:lookIntoChild(action[1], topNode)
return topNode, outNode
end
if action.op.type == 'not' then
outNode = outNode or topNode:copy()
outNode, topNode = tracer:lookIntoChild(action[1], topNode, outNode)
outNode = outNode:copy()
end
tracer:lookIntoChild(action[1], topNode)
return topNode, outNode
end)
---@param action parser.object
---@param topNode vm.node
---@param outNode? vm.node
---@return vm.node topNode
---@return vm.node outNode
function mt:lookIntoChild(action, topNode, outNode)
if not self.careMap[action]
or self.mark[action] then
return topNode, outNode or topNode
end
self.mark[action] = true
topNode = self:fastWardCasts(action.start, topNode)
topNode, outNode = lookIntoChild(action.type, self, action, topNode, outNode)
return topNode, outNode or topNode
end
---@param block parser.object
---@param start integer
---@param node vm.node
function mt:lookIntoBlock(block, start, node)
self:resetCastsIndex(start)
for _, action in ipairs(block) do
if (action.effect or action.start) < start then
goto CONTINUE
end
if self.careMap[action] then
node = self:lookIntoChild(action, node)
if action.type == 'do'
or action.type == 'loop'
or action.type == 'in'
or action.type == 'repeat' then
return
end
end
if action.finish > start and self.assignMap[action] then
return
end
::CONTINUE::
end
self.nodes[block] = node
if block.type == 'repeat' then
self:lookIntoChild(block.filter, node)
end
if block.type == 'do'
or block.type == 'loop'
or block.type == 'in'
or block.type == 'repeat' then
self:lookIntoBlock(block.parent, block.finish, node)
end
end
---@param source parser.object
function mt:calcNode(source)
if self.getMap[source] then
local lastAssign = self:getLastAssign(0, source.finish)
if not lastAssign then
return
end
if self.fastCalc then
self.nodes[source] = vm.compileNode(lastAssign)
return
end
self:calcNode(lastAssign)
return
end
if self.assignMap[source] then
local node = vm.compileNode(source)
self.nodes[source] = node
local parentBlock = guide.getParentBlock(source)
if parentBlock then
self:lookIntoBlock(parentBlock, source.finish, node)
end
return
end
end
---@param source parser.object
---@return vm.node?
function mt:getNode(source)
local cache = self.nodes[source]
if cache ~= nil then
return cache or nil
end
if source == self.main then
self.nodes[source] = false
return nil
end
self.nodes[source] = false
self:calcNode(source)
return self.nodes[source] or nil
end
---@class vm.node
---@field package _tracer vm.tracer
---@param mode tracer.mode
---@param source parser.object | vm.variable
---@param name string
---@return vm.tracer?
local function createTracer(mode, source, name)
local node = vm.compileNode(source)
local tracer = node._tracer
if tracer then
return tracer
end
local main
if source.type == 'variable' then
---@cast source vm.variable
main = guide.getParentBlock(source.base)
else
---@cast source parser.object
main = guide.getParentBlock(source)
end
if not main then
return nil
end
tracer = setmetatable({
source = source,
mode = mode,
name = name,
assigns = {},
assignMap = {},
getMap = {},
careMap = {},
mark = {},
casts = {},
nodes = {},
main = main,
uri = guide.getUri(main),
}, mt)
node._tracer = tracer
if tracer.mode == 'local' then
tracer:collectLocal()
else
tracer:collectGlobal()
end
return tracer
end
---@param source parser.object
---@return vm.node?
function vm.traceNode(source)
local mode, base, name
if vm.getGlobalNode(source) then
base = vm.getGlobalBase(source)
if not base then
return nil
end
mode = 'global'
name = base.global:getCodeName()
else
base = vm.getVariable(source)
if not base then
return nil
end
name = base:getCodeName()
mode = 'local'
end
local tracer = createTracer(mode, base, name)
if not tracer then
return nil
end
local node = tracer:getNode(source)
return node
end

View File

@@ -0,0 +1,915 @@
---@class vm
local vm = require 'vm.vm'
local guide = require 'parser.guide'
local config = require 'config.config'
local util = require 'utility'
local lang = require 'language'
---@class vm.ANY
vm.ANY = {'<VM.ANY>'}
---@class vm.ANYDOC
vm.ANYDOC = {'<VM.ANYDOC>'}
---@alias typecheck.err vm.node.object|string|vm.node
---@param object vm.node.object
---@return string?
function vm.getNodeName(object)
if object.type == 'global' and object.cate == 'type' then
---@cast object vm.global
return object.name
end
if object.type == 'nil'
or object.type == 'boolean'
or object.type == 'number'
or object.type == 'string'
or object.type == 'table'
or object.type == 'function'
or object.type == 'integer' then
return object.type
end
if object.type == 'doc.type.boolean' then
return 'boolean'
end
if object.type == 'doc.type.integer' then
return 'integer'
end
if object.type == 'doc.type.function' then
return 'function'
end
if object.type == 'doc.type.table' then
return 'table'
end
if object.type == 'doc.type.array' then
return 'table'
end
if object.type == 'doc.type.string' then
return 'string'
end
if object.type == 'doc.field.name' then
return 'string'
end
return nil
end
---@param parentName string
---@param child vm.node.object
---@param uri uri
---@param mark table
---@param errs? typecheck.err[]
---@return boolean?
local function checkParentEnum(parentName, child, uri, mark, errs)
local parentClass = vm.getGlobal('type', parentName)
if not parentClass then
return nil
end
local enums
for _, set in ipairs(parentClass:getSets(uri)) do
if set.type == 'doc.enum' then
local denums = vm.getEnums(set)
if denums then
if enums then
enums = util.arrayMerge(enums, denums)
else
enums = util.arrayMerge({}, denums)
end
end
end
end
if not enums then
return nil
end
if child.type == 'global' then
---@cast child vm.global
for _, enum in ipairs(enums) do
if vm.isSubType(uri, child, vm.compileNode(enum), mark) then
return true
end
end
if errs then
errs[#errs+1] = 'TYPE_ERROR_ENUM_GLOBAL_DISMATCH'
errs[#errs+1] = child
errs[#errs+1] = parentClass
end
return false
elseif child.type == 'generic' then
---@cast child vm.generic
if errs then
errs[#errs+1] = 'TYPE_ERROR_ENUM_GENERIC_UNSUPPORTED'
errs[#errs+1] = child
end
return false
else
---@cast child parser.object
local childName = vm.getNodeName(child)
if childName == 'number'
or childName == 'integer'
or childName == 'boolean'
or childName == 'string' then
for _, enum in ipairs(enums) do
for nd in vm.compileNode(enum):eachObject() do
if childName == vm.getNodeName(nd) and nd[1] == child[1] then
return true
end
end
end
if errs then
errs[#errs+1] = 'TYPE_ERROR_ENUM_LITERAL_DISMATCH'
errs[#errs+1] = child[1]
errs[#errs+1] = parentClass
end
return false
elseif childName == 'function'
or childName == 'table' then
for _, enum in ipairs(enums) do
for nd in vm.compileNode(enum):eachObject() do
if child == nd then
return true
end
end
end
if errs then
errs[#errs+1] = 'TYPE_ERROR_ENUM_OBJECT_DISMATCH'
errs[#errs+1] = child
errs[#errs+1] = parentClass
end
return false
end
if errs then
errs[#errs+1] = 'TYPE_ERROR_ENUM_NO_OBJECT'
errs[#errs+1] = child
end
return false
end
end
---@param childName string
---@param parent vm.node.object
---@param uri uri
---@param mark table
---@param errs? typecheck.err[]
---@return boolean?
local function checkChildEnum(childName, parent, uri, mark, errs)
if mark[childName] then
return
end
local childClass = vm.getGlobal('type', childName)
if not childClass then
return nil
end
local enums
for _, set in ipairs(childClass:getSets(uri)) do
if set.type == 'doc.enum' then
enums = vm.getEnums(set)
break
end
end
if not enums then
return nil
end
mark[childName] = true
for _, enum in ipairs(enums) do
if not vm.isSubType(uri, vm.compileNode(enum), parent, mark, errs) then
mark[childName] = nil
return false
end
end
mark[childName] = nil
return true
end
---@param parent vm.node.object
---@param child vm.node.object
---@param mark table
---@param errs? typecheck.err[]
---@return boolean
local function checkValue(parent, child, mark, errs)
if parent.type == 'doc.type.integer' then
if child.type == 'integer'
or child.type == 'doc.type.integer'
or child.type == 'number' then
if parent[1] ~= child[1] then
if errs then
errs[#errs+1] = 'TYPE_ERROR_INTEGER_DISMATCH'
errs[#errs+1] = child[1]
errs[#errs+1] = parent[1]
end
return false
end
end
return true
end
if parent.type == 'doc.type.string'
or parent.type == 'doc.field.name' then
if child.type == 'string'
or child.type == 'doc.type.string'
or child.type == 'doc.field.name' then
if parent[1] ~= child[1] then
if errs then
errs[#errs+1] = 'TYPE_ERROR_STRING_DISMATCH'
errs[#errs+1] = child[1]
errs[#errs+1] = parent[1]
end
return false
end
end
return true
end
if parent.type == 'doc.type.boolean' then
if child.type == 'boolean'
or child.type == 'doc.type.boolean' then
if parent[1] ~= child[1] then
if errs then
errs[#errs+1] = 'TYPE_ERROR_BOOLEAN_DISMATCH'
errs[#errs+1] = child[1]
errs[#errs+1] = parent[1]
end
return false
end
end
return true
end
if parent.type == 'doc.type.table' then
if child.type == 'doc.type.table' then
if child == parent then
return true
end
---@cast parent parser.object
---@cast child parser.object
local uri = guide.getUri(parent)
local tnode = vm.compileNode(child)
for _, pfield in ipairs(parent.fields) do
local knode = vm.compileNode(pfield.name)
local cvalues = vm.getTableValue(uri, tnode, knode, true)
if not cvalues then
if pfield.optional then
goto continue
end
if errs then
errs[#errs+1] = 'TYPE_ERROR_TABLE_NO_FIELD'
errs[#errs+1] = pfield.name
end
return false
end
local pvalues = vm.compileNode(pfield.extends)
if vm.isSubType(uri, cvalues, pvalues, mark, errs) == false then
if errs then
errs[#errs+1] = 'TYPE_ERROR_TABLE_FIELD_DISMATCH'
errs[#errs+1] = pfield.name
errs[#errs+1] = cvalues
errs[#errs+1] = pvalues
end
return false
end
::continue::
end
end
return true
end
return true
end
---@param name string
---@param suri uri
---@return boolean
local function isAlias(name, suri)
local global = vm.getGlobal('type', name)
if not global then
return false
end
for _, set in ipairs(global:getSets(suri)) do
if set.type == 'doc.alias' then
return true
end
end
return false
end
local function checkTableShape(parent, child, uri, mark, errs)
local set = parent:getSets(uri)
local missedKeys = {}
local failedCheck
local myKeys
for _, def in ipairs(set) do
if not def.fields or #def.fields == 0 then
goto continue
end
if not myKeys then
myKeys = {}
for _, field in ipairs(child) do
local key = vm.getKeyName(field) or field.tindex
if key then
myKeys[key] = vm.compileNode(field)
end
end
end
for _, field in ipairs(def.fields) do
local key = vm.getKeyName(field)
if not key then
local fieldnode = vm.compileNode(field.field)[1]
if fieldnode and fieldnode.type == 'doc.type.integer' then
---@cast fieldnode parser.object
key = vm.getKeyName(fieldnode)
end
end
if not key then
goto continue
end
local ok
local nodeField = vm.compileNode(field)
if myKeys[key] then
ok = vm.isSubType(uri, myKeys[key], nodeField, mark, errs)
if ok == false then
if errs then
errs[#errs+1] = 'TYPE_ERROR_PARENT_ALL_DISMATCH' -- error display can be greatly improved
errs[#errs+1] = myKeys[key]
errs[#errs+1] = nodeField
end
failedCheck = true
end
elseif not nodeField:isNullable() then
if type(key) == "number" then
missedKeys[#missedKeys+1] = ('`[%s]`'):format(key)
else
missedKeys[#missedKeys+1] = ('`%s`'):format(key)
end
failedCheck = true
end
end
::continue::
end
if errs and #missedKeys > 0 then
errs[#errs+1] = 'DIAG_MISSING_FIELDS'
errs[#errs+1] = parent
errs[#errs+1] = table.concat(missedKeys, ', ')
end
if failedCheck then
return false
end
return true
end
---@param uri uri
---@param child vm.node|string|vm.node.object
---@param parent vm.node|string|vm.node.object
---@param mark? table
---@param errs? typecheck.err[]
---@return boolean|nil
function vm.isSubType(uri, child, parent, mark, errs)
mark = mark or {}
if type(child) == 'string' then
local global = vm.getGlobal('type', child)
if not global then
return nil
end
child = global
elseif child.type == 'vm.node' then
if config.get(uri, 'Lua.type.weakUnionCheck') then
local hasKnownType = 0
local i = 0
for n in child:eachObject() do
i = i + 1
if i > 100 then
break
end
if vm.getNodeName(n) then
local res = vm.isSubType(uri, n, parent, mark, errs)
if res == true then
return true
elseif res == false then
hasKnownType = hasKnownType + 1
end
end
end
if hasKnownType > 0 then
if errs
and hasKnownType > 1
and #vm.getInfer(child):getSubViews(uri) > 1 then
errs[#errs+1] = 'TYPE_ERROR_CHILD_ALL_DISMATCH'
errs[#errs+1] = child
errs[#errs+1] = parent
end
return false
end
return true
else
local weakNil = config.get(uri, 'Lua.type.weakNilCheck')
local skipTable
local i = 0
for n in child:eachObject() do
i = i + 1
if i > 100 then
break
end
if skipTable == nil and n.type == "table" and parent.type == "vm.node" then -- skip table type check if child has class
---@cast parent vm.node
for _, c in ipairs(child) do
if c.type == 'global' and c.cate == 'type' then
for _, set in ipairs(c:getSets(uri)) do
if set.type == 'doc.class' then
skipTable = true
break
end
end
end
if skipTable then
break
end
end
if not skipTable then
skipTable = false
end
end
local nodeName = vm.getNodeName(n)
if nodeName
and not (nodeName == 'nil' and weakNil) and not (skipTable and n.type == 'table')
and vm.isSubType(uri, n, parent, mark, errs) == false then
if errs then
errs[#errs+1] = 'TYPE_ERROR_UNION_DISMATCH'
errs[#errs+1] = n
errs[#errs+1] = parent
end
return false
end
end
if not weakNil and child:isOptional() then
if vm.isSubType(uri, 'nil', parent, mark, errs) == false then
if errs then
errs[#errs+1] = 'TYPE_ERROR_OPTIONAL_DISMATCH'
errs[#errs+1] = parent
end
return false
end
end
return true
end
end
---@cast child vm.node.object
local childName = vm.getNodeName(child)
if childName == 'any'
or childName == 'unknown' then
return true
end
if not childName
or isAlias(childName, uri) then
return nil
end
if type(parent) == 'string' then
local global = vm.getGlobal('type', parent)
if not global then
return false
end
parent = global
elseif parent.type == 'vm.node' then
local hasKnownType = 0
local i = 0
for n in parent:eachObject() do
i = i + 1
if i > 100 then
break
end
if vm.getNodeName(n) then
local res = vm.isSubType(uri, child, n, mark, errs)
if res == true then
return true
elseif res == false then
hasKnownType = hasKnownType + 1
end
end
if n.type == 'doc.generic.name' then
return true
end
end
if parent:isOptional() then
if vm.isSubType(uri, child, 'nil', mark, errs) == true then
return true
end
end
if hasKnownType > 0 then
if errs
and hasKnownType > 1
and #vm.getInfer(parent):getSubViews(uri) > 1 then
errs[#errs+1] = 'TYPE_ERROR_PARENT_ALL_DISMATCH'
errs[#errs+1] = child
errs[#errs+1] = parent
end
return false
end
return true
end
---@cast parent vm.node.object
local parentName = vm.getNodeName(parent)
if parentName == 'any'
or parentName == 'unknown' then
return true
end
if not parentName
or isAlias(parentName, uri) then
return nil
end
if childName == parentName then
if not checkValue(parent, child, mark, errs) then
return false
end
return true
end
if parentName == 'number' and childName == 'integer' then
return true
end
if parentName == 'integer' and childName == 'number' then
if config.get(uri, 'Lua.type.castNumberToInteger') then
return true
end
if child.type == 'number'
and child[1]
and not math.tointeger(child[1]) then
if errs then
errs[#errs+1] = 'TYPE_ERROR_NUMBER_LITERAL_TO_INTEGER'
errs[#errs+1] = child[1]
end
return false
end
if child.type == 'global'
and child.cate == 'type' then
if errs then
errs[#errs+1] = 'TYPE_ERROR_NUMBER_TYPE_TO_INTEGER'
end
return false
end
return true
end
local result = checkParentEnum(parentName, child, uri, mark, errs)
if result ~= nil then
return result
end
result = checkChildEnum(childName, parent, uri, mark, errs)
if result ~= nil then
return result
end
if parentName == 'table' and not guide.isBasicType(childName) then
return true
end
if childName == 'table' and not guide.isBasicType(parentName) then
if config.get(uri, 'Lua.type.checkTableShape') then
return checkTableShape(parent, child, uri, mark, errs)
else
return true
end
end
-- check class parent
if childName and not mark[childName] then
mark[childName] = true
local isBasicType = guide.isBasicType(childName)
local childClass = vm.getGlobal('type', childName)
if childClass then
for _, set in ipairs(childClass:getSets(uri)) do
if set.type == 'doc.class' and set.extends then
for _, ext in ipairs(set.extends) do
if ext.type == 'doc.extends.name'
and (not isBasicType or guide.isBasicType(ext[1]))
and vm.isSubType(uri, ext[1], parent, mark, errs) == true then
mark[childName] = nil
return true
end
end
end
end
end
mark[childName] = nil
end
--[[
---@class A: string
---@type A
local x = '' --> `string` set to `A`
]]
if guide.isBasicType(childName)
and not mark[childName] then
mark[childName] = true
if vm.isSubType(uri, parentName, childName, mark) then
mark[childName] = nil
return true
end
mark[childName] = nil
end
if errs then
errs[#errs+1] = 'TYPE_ERROR_DISMATCH'
errs[#errs+1] = child
errs[#errs+1] = parent
end
return false
end
---@param node string|vm.node|vm.object
function vm.isUnknown(node)
if type(node) == 'string' then
return node == 'unknown'
end
if node.type == 'vm.node' then
return not node:hasKnownType()
end
return false
end
---@param uri uri
---@param tnode vm.node
---@param knode vm.node|string
---@param inversion? boolean
---@return vm.node?
function vm.getTableValue(uri, tnode, knode, inversion)
local result = vm.createNode()
local inferSize = config.get(uri, "Lua.type.inferTableSize")
for tn in tnode:eachObject() do
if tn.type == 'doc.type.table' then
for _, field in ipairs(tn.fields) do
if field.extends then
if inversion then
if vm.isSubType(uri, vm.compileNode(field.name), knode) then
result:merge(vm.compileNode(field.extends))
end
else
if vm.isSubType(uri, knode, vm.compileNode(field.name)) then
result:merge(vm.compileNode(field.extends))
end
end
end
end
end
if tn.type == 'doc.type.array' then
result:merge(vm.compileNode(tn.node))
end
if tn.type == 'table' then
if vm.isUnknown(knode) then
goto CONTINUE
end
for _, field in ipairs(tn) do
if field.type == 'tableindex'
and field.value then
result:merge(vm.compileNode(field.value))
end
if field.type == 'tablefield'
and field.value then
if inversion then
if vm.isSubType(uri, 'string', knode) then
result:merge(vm.compileNode(field.value))
end
else
if vm.isSubType(uri, knode, 'string') then
result:merge(vm.compileNode(field.value))
end
end
end
if field.type == 'tableexp'
and field.value
and field.tindex <= inferSize then
if inversion then
if vm.isSubType(uri, 'integer', knode) then
result:merge(vm.compileNode(field.value))
end
else
if vm.isSubType(uri, knode, 'integer') then
result:merge(vm.compileNode(field.value))
end
end
end
if field.type == 'varargs' then
result:merge(vm.compileNode(field))
end
end
end
::CONTINUE::
end
if result:isEmpty() then
return nil
end
return result
end
---@param uri uri
---@param tnode vm.node
---@param vnode vm.node|string|vm.object
---@param reverse? boolean
---@return vm.node?
function vm.getTableKey(uri, tnode, vnode, reverse)
local result = vm.createNode()
for tn in tnode:eachObject() do
if tn.type == 'doc.type.table' then
for _, field in ipairs(tn.fields) do
if field.name.type ~= 'doc.field.name'
and field.extends then
if reverse then
if vm.isSubType(uri, vm.compileNode(field.extends), vnode) then
result:merge(vm.compileNode(field.name))
end
else
if vm.isSubType(uri, vnode, vm.compileNode(field.extends)) then
result:merge(vm.compileNode(field.name))
end
end
end
end
end
if tn.type == 'doc.type.array' then
result:merge(vm.declareGlobal('type', 'integer'))
end
if tn.type == 'table' then
if vm.isUnknown(tnode) then
goto CONTINUE
end
for _, field in ipairs(tn) do
if field.type == 'tableindex' then
if field.index then
result:merge(vm.compileNode(field.index))
end
end
if field.type == 'tablefield' then
result:merge(vm.declareGlobal('type', 'string'))
end
if field.type == 'tableexp' or field.type == 'varargs' then
result:merge(vm.declareGlobal('type', 'integer'))
end
end
end
::CONTINUE::
end
if result:isEmpty() then
return nil
end
return result
end
---@param uri uri
---@param defNode vm.node
---@param refNode vm.node
---@param errs typecheck.err[]?
---@return boolean
function vm.canCastType(uri, defNode, refNode, errs)
local defInfer = vm.getInfer(defNode)
local refInfer = vm.getInfer(refNode)
if defInfer:hasAny(uri) then
return true
end
if refInfer:hasAny(uri) then
return true
end
if defInfer:view(uri) == 'unknown' then
return true
end
if refInfer:view(uri) == 'unknown' then
return true
end
if defInfer:view(uri) == 'nil' then
return true
end
if vm.isSubType(uri, refNode, 'nil') then
-- allow `local x = {};x = nil`,
-- but not allow `local x ---@type table;x = nil`
if defInfer:hasType(uri, 'table')
and not defNode:hasType 'table' then
return true
end
end
if vm.isSubType(uri, refNode, 'number') then
-- allow `local x = 0;x = 1.0`,
-- but not allow `local x ---@type integer;x = 1.0`
if defInfer:hasType(uri, 'integer')
and not defNode:hasType 'integer' then
return true
end
end
if vm.isSubType(uri, refNode, defNode, {}, errs) then
return true
end
return false
end
local ErrorMessageMap = {
TYPE_ERROR_ENUM_GLOBAL_DISMATCH = {'child', 'parent'},
TYPE_ERROR_ENUM_GENERIC_UNSUPPORTED = {'child'},
TYPE_ERROR_ENUM_LITERAL_DISMATCH = {'child', 'parent'},
TYPE_ERROR_ENUM_OBJECT_DISMATCH = {'child', 'parent'},
TYPE_ERROR_ENUM_NO_OBJECT = {'child'},
TYPE_ERROR_INTEGER_DISMATCH = {'child', 'parent'},
TYPE_ERROR_STRING_DISMATCH = {'child', 'parent'},
TYPE_ERROR_BOOLEAN_DISMATCH = {'child', 'parent'},
TYPE_ERROR_TABLE_NO_FIELD = {'key'},
TYPE_ERROR_TABLE_FIELD_DISMATCH = {'key', 'child', 'parent'},
TYPE_ERROR_CHILD_ALL_DISMATCH = {'child', 'parent'},
TYPE_ERROR_PARENT_ALL_DISMATCH = {'child', 'parent'},
TYPE_ERROR_UNION_DISMATCH = {'child', 'parent'},
TYPE_ERROR_OPTIONAL_DISMATCH = {'parent'},
TYPE_ERROR_NUMBER_LITERAL_TO_INTEGER = {'child'},
TYPE_ERROR_NUMBER_TYPE_TO_INTEGER = {},
TYPE_ERROR_DISMATCH = {'child', 'parent'},
DIAG_MISSING_FIELDS = {"1", "2"},
}
---@param uri uri
---@param errs typecheck.err[]
---@return string
function vm.viewTypeErrorMessage(uri, errs)
local lines = {}
local mark = {}
local index = 1
while true do
local name = errs[index]
if not name then
break
end
index = index + 1
local params = ErrorMessageMap[name]
local lparams = {}
for _, paramName in ipairs(params) do
local value = errs[index]
if type(value) == 'string'
or type(value) == 'number'
or type(value) == 'boolean' then
lparams[paramName] = util.viewLiteral(value)
elseif value.type == 'global' then
lparams[paramName] = value.name
elseif value.type == 'vm.node' then
---@cast value vm.node
lparams[paramName] = vm.getInfer(value):view(uri)
elseif value.type == 'table' then
lparams[paramName] = 'table'
elseif value.type == 'generic' then
---@cast value vm.generic
lparams[paramName] = vm.getInfer(value):view(uri)
elseif value.type == 'variable' then
else
---@cast value -string, -vm.global, -vm.node, -vm.generic, -vm.variable
if paramName == 'key' then
lparams[paramName] = vm.viewKey(value, uri)
else
lparams[paramName] = vm.getInfer(value):view(uri)
or vm.getInfer(value):view(uri)
end
end
index = index + 1
end
local line = lang.script(name, lparams)
if not mark[line] then
mark[line] = true
lines[#lines+1] = '- ' .. line
end
end
util.revertArray(lines)
if #lines > 15 then
lines[13] = ('...(+%d)'):format(#lines - 15)
table.move(lines, #lines - 2, #lines, 14)
return table.concat(lines, '\n', 1, 16)
else
return table.concat(lines, '\n')
end
end
---@param name string
---@param uri uri
---@return parser.object[]?
function vm.getOverloadsByTypeName(name, uri)
local global = vm.getGlobal('type', name)
if not global then
return nil
end
local results
for _, set in ipairs(global:getSets(uri)) do
for _, doc in ipairs(set.bindGroup) do
if doc.type == 'doc.overload' then
if not results then
results = {}
end
results[#results+1] = doc.overload
end
end
end
return results
end

View File

@@ -0,0 +1,246 @@
local guide = require 'parser.guide'
---@class vm
local vm = require 'vm.vm'
---@param source parser.object?
---@return boolean|nil
function vm.testCondition(source)
if not source then
return nil
end
local node = vm.compileNode(source)
if node.optional then
return nil
end
local hasTrue, hasFalse
for n in node:eachObject() do
if n.type == 'boolean'
or n.type == 'doc.type.boolean' then
if n[1] == true then
hasTrue = true
end
if n[1] == false then
hasFalse = true
end
elseif n.type == 'global' and n.cate == 'type' then
if n.name == 'boolean'
or n.name == 'unknown' then
return nil
end
if n.name == 'false'
or n.name == 'nil' then
hasFalse = true
else
hasTrue = true
end
elseif n.type == 'nil' then
hasFalse = true
elseif guide.isLiteral(n) then
hasTrue = true
end
end
if hasTrue == hasFalse then
return nil
end
if hasTrue then
return true
else
return false
end
end
---@param v vm.node.object
---@return string|false
local function getUnique(v)
if v.type == 'boolean' then
if v[1] == nil then
return false
end
return ('%s'):format(v[1])
end
if v.type == 'number' then
if not v[1] then
return false
end
return ('num:%s'):format(v[1])
end
if v.type == 'integer' then
if not v[1] then
return false
end
return ('num:%s'):format(v[1])
end
if v.type == 'table' then
---@cast v parser.object
return ('table:%s@%d'):format(guide.getUri(v), v.start)
end
if v.type == 'function' then
---@cast v parser.object
return ('func:%s@%d'):format(guide.getUri(v), v.start)
end
return false
end
---@param a parser.object?
---@param b parser.object?
---@return boolean|nil
function vm.equal(a, b)
if not a or not b then
return false
end
local nodeA = vm.compileNode(a)
local nodeB = vm.compileNode(b)
local mapA = {}
for obj in nodeA:eachObject() do
local unique = getUnique(obj)
if not unique then
return nil
end
mapA[unique] = true
end
for obj in nodeB:eachObject() do
local unique = getUnique(obj)
if not unique then
return nil
end
if not mapA[unique] then
return false
end
end
return true
end
---@param v vm.object?
---@return integer?
function vm.getInteger(v)
if not v then
return nil
end
local node = vm.compileNode(v)
local result
for n in node:eachObject() do
if n.type == 'integer' then
if result then
return nil
else
result = n[1]
end
elseif n.type == 'number' then
if result then
return nil
elseif not math.tointeger(n[1]) then
return nil
else
result = math.tointeger(n[1])
end
elseif n.type ~= 'local'
and n.type ~= 'global' then
return nil
end
end
return result
end
---@param v vm.object?
---@return string?
function vm.getString(v)
if not v then
return nil
end
local node = vm.compileNode(v)
local result
for n in node:eachObject() do
if n.type == 'string' then
if result then
return nil
else
result = n[1]
end
elseif n.type ~= 'local'
and n.type ~= 'global' then
return nil
end
end
return result
end
---@param v vm.object?
---@return number?
function vm.getNumber(v)
if not v then
return nil
end
local node = vm.compileNode(v)
local result
for n in node:eachObject() do
if n.type == 'number'
or n.type == 'integer' then
if result then
return nil
else
result = n[1]
end
elseif n.type ~= 'local'
and n.type ~= 'global' then
return nil
end
end
return result
end
---@param v vm.object
---@return boolean|nil
function vm.getBoolean(v)
if not v then
return nil
end
local node = vm.compileNode(v)
local result
for n in node:eachObject() do
if n.type == 'boolean' then
if result then
return nil
else
result = n[1]
end
elseif n.type ~= 'local'
and n.type ~= 'global' then
return nil
end
end
return result
end
---@param v vm.object
---@return table<any, boolean>?
---@return integer
function vm.getLiterals(v)
if not v then
return nil, 0
end
local map
local count = 0
local node = vm.compileNode(v)
for n in node:eachObject() do
local literal
if n.type == 'boolean'
or n.type == 'string'
or n.type == 'number'
or n.type == 'integer' then
literal = n[1]
end
if n.type == 'doc.type.string'
or n.type == 'doc.type.integer'
or n.type == 'doc.type.boolean' then
literal = n[1]
end
if literal ~= nil then
if not map then
map = {}
end
map[literal] = true
count = count + 1
end
end
return map, count
end

View File

@@ -0,0 +1,411 @@
local util = require 'utility'
local guide = require 'parser.guide'
---@class vm
local vm = require 'vm.vm'
---@class vm.variable
---@field uri uri
---@field root parser.object
---@field id string
---@field base parser.object
---@field sets parser.object[]
---@field gets parser.object[]
local mt = {}
mt.__index = mt
mt.type = 'variable'
---@param id string
---@return vm.variable
local function createVariable(root, id)
local variable = setmetatable({
root = root,
uri = root.uri,
id = id,
sets = {},
gets = {},
}, mt)
return variable
end
---@class parser.object
---@field package _variableNode vm.variable|false
---@field package _variableNodes table<string, vm.variable>
local compileVariables, getLoc
---@param id string
---@param source parser.object
---@param base parser.object
---@return vm.variable
local function insertVariableID(id, source, base)
local root = guide.getRoot(source)
if not root._variableNodes then
root._variableNodes = util.multiTable(2, function (lid)
local variable = createVariable(root, lid)
return variable
end)
end
local variable = root._variableNodes[id]
variable.base = base
if guide.isAssign(source) then
variable.sets[#variable.sets+1] = source
else
variable.gets[#variable.gets+1] = source
end
return variable
end
local compileSwitch = util.switch()
: case 'local'
: case 'self'
: call(function (source, base)
local id = ('%d'):format(source.start)
local variable = insertVariableID(id, source, base)
source._variableNode = variable
if not source.ref then
return
end
for _, ref in ipairs(source.ref) do
compileVariables(ref, base)
end
end)
: case 'setlocal'
: case 'getlocal'
: call(function (source, base)
local id = ('%d'):format(source.node.start)
local variable = insertVariableID(id, source, base)
source._variableNode = variable
compileVariables(source.next, base)
end)
: case 'getfield'
: case 'setfield'
: call(function (source, base)
local parentNode = source.node._variableNode
if not parentNode then
return
end
local key = guide.getKeyName(source)
if type(key) ~= 'string' then
return
end
local id = parentNode.id .. vm.ID_SPLITE .. key
local variable = insertVariableID(id, source, base)
source._variableNode = variable
source.field._variableNode = variable
if source.type == 'getfield' then
compileVariables(source.next, base)
end
end)
: case 'getmethod'
: case 'setmethod'
: call(function (source, base)
local parentNode = source.node._variableNode
if not parentNode then
return
end
local key = guide.getKeyName(source)
if type(key) ~= 'string' then
return
end
local id = parentNode.id .. vm.ID_SPLITE .. key
local variable = insertVariableID(id, source, base)
source._variableNode = variable
source.method._variableNode = variable
if source.type == 'getmethod' then
compileVariables(source.next, base)
end
end)
: case 'getindex'
: case 'setindex'
: call(function (source, base)
local parentNode = source.node._variableNode
if not parentNode then
return
end
local key = guide.getKeyName(source)
if type(key) ~= 'string' then
return
end
local id = parentNode.id .. vm.ID_SPLITE .. key
local variable = insertVariableID(id, source, base)
source._variableNode = variable
source.index._variableNode = variable
if source.type == 'setindex' then
compileVariables(source.next, base)
end
end)
local leftSwitch = util.switch()
: case 'field'
: case 'method'
: call(function (source)
return getLoc(source.parent)
end)
: case 'getfield'
: case 'setfield'
: case 'getmethod'
: case 'setmethod'
: case 'getindex'
: case 'setindex'
: call(function (source)
return getLoc(source.node)
end)
: case 'getlocal'
: call(function (source)
return source.node
end)
: case 'local'
: case 'self'
: call(function (source)
return source
end)
---@param source parser.object
---@return parser.object?
function getLoc(source)
return leftSwitch(source.type, source)
end
---@return parser.object
function mt:getBase()
return self.base
end
---@return string
function mt:getCodeName()
local name = self.id:gsub(vm.ID_SPLITE, '.'):gsub('^%d+', self.base[1])
return name
end
---@return vm.variable?
function mt:getParent()
local parentID = self.id:match('^(.+)' .. vm.ID_SPLITE)
if not parentID then
return nil
end
return self.root._variableNodes[parentID]
end
---@return string?
function mt:getFieldName()
return self.id:match(vm.ID_SPLITE .. '(.-)$')
end
---@param key? string
function mt:getSets(key)
if not key then
return self.sets
end
local id = self.id .. vm.ID_SPLITE .. key
local variable = self.root._variableNodes[id]
return variable.sets
end
---@param includeGets boolean?
function mt:getFields(includeGets)
local id = self.id
local root = self.root
-- TODOoptimize
local clock = os.clock()
local fields = {}
for lid, variable in pairs(root._variableNodes) do
if lid ~= id
and util.stringStartWith(lid, id)
and lid:sub(#id + 1, #id + 1) == vm.ID_SPLITE
-- only one field
and not lid:find(vm.ID_SPLITE, #id + 2) then
for _, src in ipairs(variable.sets) do
fields[#fields+1] = src
end
if includeGets then
for _, src in ipairs(variable.gets) do
fields[#fields+1] = src
end
end
end
end
local cost = os.clock() - clock
if cost > 1.0 then
log.warn('variable-id getFields takes %.3f seconds', cost)
end
return fields
end
---@param source parser.object
---@param base parser.object
function compileVariables(source, base)
if not source then
return
end
source._variableNode = false
if not compileSwitch:has(source.type) then
return
end
compileSwitch(source.type, source, base)
end
---@param source parser.object
---@return string?
function vm.getVariableID(source)
local variable = vm.getVariableNode(source)
if not variable then
return nil
end
return variable.id
end
---@param source parser.object
---@param key? string
---@return vm.variable?
function vm.getVariable(source, key)
local variable = vm.getVariableNode(source)
if not variable then
return nil
end
if not key then
return variable
end
local root = guide.getRoot(source)
if not root._variableNodes then
return nil
end
local id = variable.id .. vm.ID_SPLITE .. key
return root._variableNodes[id]
end
---@param source parser.object
---@return vm.variable?
function vm.getVariableNode(source)
local variable = source._variableNode
if variable ~= nil then
return variable or nil
end
source._variableNode = false
local loc = getLoc(source)
if not loc then
return nil
end
compileVariables(loc, loc)
return source._variableNode or nil
end
---@param source parser.object
---@param name string
---@return vm.variable?
function vm.getVariableInfoByCodeName(source, name)
local id = vm.getVariableID(source)
if not id then
return nil
end
local root = guide.getRoot(source)
if not root._variableNodes then
return nil
end
local headPos = name:find('.', 1, true)
if not headPos then
return root._variableNodes[id]
end
local vid = id .. name:sub(headPos):gsub('%.', vm.ID_SPLITE)
return root._variableNodes[vid]
end
---@param source parser.object
---@param key? string
---@return parser.object[]?
function vm.getVariableSets(source, key)
local variable = vm.getVariable(source, key)
if not variable then
return nil
end
return variable.sets
end
---@param source parser.object
---@param key? string
---@return parser.object[]?
function vm.getVariableGets(source, key)
local variable = vm.getVariable(source, key)
if not variable then
return nil
end
return variable.gets
end
---@param source parser.object
---@param includeGets boolean
---@return parser.object[]?
function vm.getVariableFields(source, includeGets)
local variable = vm.getVariable(source)
if not variable then
return nil
end
return variable:getFields(includeGets)
end
---@param source parser.object
---@return boolean
function vm.compileByVariable(source)
local variable = vm.getVariableNode(source)
if not variable then
return false
end
vm.setNode(source, variable)
return true
end
---@param source parser.object
local function compileSelf(source)
if source.parent.type ~= 'funcargs' then
return
end
---@type parser.object
local node = source.parent.parent and source.parent.parent.parent and source.parent.parent.parent.node
if not node then
return
end
local fields = vm.getVariableFields(source, false)
if not fields then
return
end
local variableNode = vm.getVariableNode(node)
local globalNode = vm.getGlobalNode(node)
if not variableNode and not globalNode then
return
end
for _, field in ipairs(fields) do
if field.type == 'setfield' then
local key = guide.getKeyName(field)
if key then
if variableNode then
local myID = variableNode.id .. vm.ID_SPLITE .. key
insertVariableID(myID, field, variableNode.base)
end
if globalNode then
local myID = globalNode:getName() .. vm.ID_SPLITE .. key
local myGlobal = vm.declareGlobal('variable', myID, guide.getUri(node))
myGlobal:addSet(guide.getUri(node), field)
end
end
end
end
end
---@param source parser.object
local function compileAst(source)
--[[
local mt
function mt:xxx()
self.a = 1
end
mt.a --> find this definition
]]
guide.eachSourceType(source, 'self', function (src)
compileSelf(src)
end)
end
return {
compileAst = compileAst,
}

View File

@@ -0,0 +1,205 @@
---@class vm
local vm = require 'vm.vm'
local guide = require 'parser.guide'
local config = require 'config'
local glob = require 'glob'
---@class parser.object
---@field package _visibleType? parser.visibleType
local function globMatch(patterns, fieldName)
return glob.glob(patterns)(fieldName)
end
local function luaMatch(patterns, fieldName)
for i = 1, #patterns do
if string.find(fieldName, patterns[i]) then
return true
end
end
return false
end
local function getVisibleType(source)
if guide.isLiteral(source) then
return 'public'
end
if source._visibleType then
return source._visibleType
end
if source.type == 'doc.field' then
if source.visible then
source._visibleType = source.visible
return source.visible
end
end
if source.bindDocs then
for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.private' then
source._visibleType = 'private'
return 'private'
end
if doc.type == 'doc.protected' then
source._visibleType = 'protected'
return 'protected'
end
if doc.type == 'doc.package' then
source._visibleType = 'package'
return 'package'
end
end
end
local fieldName = guide.getKeyName(source)
if type(fieldName) == 'string' then
local uri = guide.getUri(source)
local regengine = config.get(uri, 'Lua.doc.regengine')
local match = regengine == "glob" and globMatch or luaMatch
local privateNames = config.get(uri, 'Lua.doc.privateName')
if #privateNames > 0 and match(privateNames, fieldName) then
source._visibleType = 'private'
return 'private'
end
local protectedNames = config.get(uri, 'Lua.doc.protectedName')
if #protectedNames > 0 and match(protectedNames, fieldName) then
source._visibleType = 'protected'
return 'protected'
end
local packageNames = config.get(uri, 'Lua.doc.packageName')
if #packageNames > 0 and match(packageNames, fieldName) then
source._visibleType = 'package'
return 'package'
end
end
source._visibleType = 'public'
return 'public'
end
---@class vm.node
---@field package _visibleType parser.visibleType
---@param source parser.object
---@return parser.visibleType
function vm.getVisibleType(source)
local node = vm.compileNode(source)
if node._visibleType then
return node._visibleType
end
for _, def in ipairs(vm.getDefs(source)) do
local visible = getVisibleType(def)
if visible ~= 'public' then
node._visibleType = visible
return visible
end
end
node._visibleType = 'public'
return 'public'
end
---@param source parser.object
---@return vm.global?
function vm.getParentClass(source)
if source.type == 'doc.field' then
return vm.getGlobalNode(source.class)
end
if source.type == 'setfield'
or source.type == 'setindex'
or source.type == 'setmethod'
or source.type == 'tableindex' then
return vm.getDefinedClass(guide.getUri(source), source.node)
end
if source.type == 'tablefield' then
return vm.getDefinedClass(guide.getUri(source), source.node) or
vm.getDefinedClass(guide.getUri(source), source.parent.parent)
end
return nil
end
---@param suri uri
---@param source parser.object
---@return vm.global?
function vm.getDefinedClass(suri, source)
source = guide.getSelfNode(source) or source
local sets = vm.getVariableSets(source)
if sets then
for _, set in ipairs(sets) do
if set.bindDocs then
for _, doc in ipairs(set.bindDocs) do
if doc.type == 'doc.class' then
return vm.getGlobalNode(doc)
end
end
end
end
end
local global = vm.getGlobalNode(source)
if global then
for _, set in ipairs(global:getSets(suri)) do
if set.bindDocs then
for _, doc in ipairs(set.bindDocs) do
if doc.type == 'doc.class' then
return vm.getGlobalNode(doc)
end
end
end
end
end
return nil
end
---@param source parser.object
---@return vm.global?
local function getEnvClass(source)
local func = guide.getParentFunction(source)
if not func or func.type ~= 'function' then
return nil
end
local parent = func.parent
if parent.type == 'setfield'
or parent.type == 'setmethod' then
local node = parent.node
return vm.getDefinedClass(guide.getUri(source), node)
end
return nil
end
---@param parent parser.object
---@param field parser.object
function vm.isVisible(parent, field)
local visible = vm.getVisibleType(field)
if visible == 'public' then
return true
end
if visible == 'package' then
return guide.getUri(parent) == guide.getUri(field)
end
local class = vm.getParentClass(field)
if not class then
return true
end
local suri = guide.getUri(parent)
-- check <?obj?>.x
local myClass = vm.getDefinedClass(suri, parent)
if not myClass then
-- check function <?mt?>:X() ... end
myClass = getEnvClass(parent)
if not myClass then
return false
end
end
if myClass == class then
return true
end
if visible == 'protected' then
if vm.isSubType(suri, myClass, class) then
return true
end
end
return false
end

View File

@@ -0,0 +1,119 @@
local guide = require 'parser.guide'
local files = require 'files'
local timer = require 'timer'
local setmetatable = setmetatable
local log = log
local xpcall = xpcall
local mathHuge = math.huge
local weakMT = { __mode = 'kv' }
---@class vm
local m = {}
m.ID_SPLITE = '\x1F'
function m.getSpecial(source)
if not source then
return nil
end
return source.special
end
---@param source parser.object
---@return string?
function m.getKeyName(source)
if not source then
return nil
end
if source.type == 'call' then
local special = m.getSpecial(source.node)
if special == 'rawset'
or special == 'rawget' then
return guide.getKeyNameOfLiteral(source.args[2])
end
end
return guide.getKeyName(source)
end
function m.getKeyType(source)
if not source then
return nil
end
if source.type == 'call' then
local special = m.getSpecial(source.node)
if special == 'rawset'
or special == 'rawget' then
return guide.getKeyTypeOfLiteral(source.args[2])
end
end
return guide.getKeyType(source)
end
---@param source parser.object
---@return parser.object?
function m.getObjectValue(source)
if source.value then
return source.value
end
if source.special == 'rawset' then
return source.args and source.args[3]
end
return nil
end
---@param source parser.object
---@return parser.object?
function m.getObjectFunctionValue(source)
local value = m.getObjectValue(source)
if value == nil then return end
if value.type == 'function' or value.type == 'doc.type.function' then
return value
end
if value.type == 'getlocal' then
return m.getObjectFunctionValue(value.node)
end
return value
end
m.cacheTracker = setmetatable({}, weakMT)
function m.flushCache()
if m.cache then
m.cache.dead = true
end
m.cacheVersion = files.globalVersion
m.cache = {}
m.cacheActiveTime = mathHuge
m.locked = setmetatable({}, weakMT)
m.cacheTracker[m.cache] = true
end
function m.getCache(name, weak)
if m.cacheVersion ~= files.globalVersion then
m.flushCache()
end
m.cacheActiveTime = timer.clock()
if not m.cache[name] then
m.cache[name] = weak and setmetatable({}, weakMT) or {}
end
return m.cache[name]
end
local function init()
m.flushCache()
-- 可以在一段时间不活动后清空缓存,不过目前看起来没有必要
--timer.loop(1, function ()
-- if timer.clock() - m.cacheActiveTime > 10.0 then
-- log.info('Flush cache: Inactive')
-- m.flushCache()
-- collectgarbage()
-- end
--end)
end
xpcall(init, log.error)
return m