gomog/internal/engine/query.go

707 lines
15 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package engine
import (
"strings"
"time"
"git.kingecg.top/kingecg/gomog/pkg/types"
)
// MatchFilter reports whether doc satisfies every top-level entry of filter.
//
// Keys naming a logical or evaluation operator ($and, $or, $nor, $not, $expr,
// $jsonSchema) are dispatched to the matching handler; every other key is
// treated as a (possibly dotted) field path and checked with matchField.
// All entries must hold for the document to match. A nil or empty filter
// matches every document.
func MatchFilter(doc map[string]interface{}, filter types.Filter) bool {
	// Ranging over a nil or empty filter runs zero iterations, so the
	// "match everything" case needs no separate guard.
	for name, clause := range filter {
		var matched bool
		switch name {
		case "$and":
			matched = handleAnd(doc, clause)
		case "$or":
			matched = handleOr(doc, clause)
		case "$nor":
			matched = handleNor(doc, clause)
		case "$not":
			matched = handleNot(doc, clause)
		case "$expr":
			matched = handleExpr(doc, clause)
		case "$jsonSchema":
			matched = handleJSONSchema(doc, clause)
		default:
			matched = matchField(doc, name, clause)
		}
		if !matched {
			return false
		}
	}
	return true
}
// handleExpr evaluates a $expr clause (an aggregation expression used as a
// query predicate) against doc.
//
// condition must be a types.Filter or a map[string]interface{}; any other
// type makes the clause fail. The expression is evaluated with a fresh
// AggregationEngine and the result is reduced to a boolean by isTrueValue.
func handleExpr(doc map[string]interface{}, condition interface{}) bool {
	expr, isMap := condition.(map[string]interface{})
	if !isMap {
		asFilter, isFilter := condition.(types.Filter)
		if !isFilter {
			return false
		}
		expr = asFilter
	}

	// A throwaway engine instance is enough to evaluate one expression.
	evaluator := &AggregationEngine{}
	return isTrueValue(evaluator.evaluateExpression(doc, expr))
}
// isTrueValue reports the truthiness of value.
//
// nil, false, numeric zero, the empty string, and empty []interface{} /
// map[string]interface{} values are false. Every other value, including
// values of types not listed here, is true.
func isTrueValue(value interface{}) bool {
	switch v := value.(type) {
	case nil:
		return false
	case bool:
		return v
	case string:
		return v != ""
	case []interface{}:
		return len(v) != 0
	case map[string]interface{}:
		return len(v) != 0
	case int, int8, int16, int32, int64,
		uint, uint8, uint16, uint32, uint64,
		float32, float64:
		return getNumericValue(v) != 0
	}
	return true
}
// handleJSONSchema evaluates a $jsonSchema clause against doc.
//
// Only a schema whose dynamic type is map[string]interface{} is validated
// (via validateJSONSchema). Any other schema value is considered invalid and
// the clause passes by default.
func handleJSONSchema(doc map[string]interface{}, schema interface{}) bool {
	if schemaMap, isMap := schema.(map[string]interface{}); isMap {
		return validateJSONSchema(doc, schemaMap)
	}
	// Invalid schema: accept the document rather than reject it.
	return true
}
// validateJSONSchema 验证文档是否符合 JSON Schema
func validateJSONSchema(doc map[string]interface{}, schema map[string]interface{}) bool {
// 检查 bsonType
if bsonType, exists := schema["bsonType"]; exists {
if !validateBsonType(doc, bsonType) {
return false
}
}
// 检查 required 字段
if requiredRaw, exists := schema["required"]; exists {
if required, ok := requiredRaw.([]interface{}); ok {
for _, reqField := range required {
if fieldStr, ok := reqField.(string); ok {
if doc[fieldStr] == nil {
return false
}
}
}
}
}
// 检查 properties
if propertiesRaw, exists := schema["properties"]; exists {
if properties, ok := propertiesRaw.(map[string]interface{}); ok {
for fieldName, fieldSchemaRaw := range properties {
if fieldSchema, ok := fieldSchemaRaw.(map[string]interface{}); ok {
fieldValue := doc[fieldName]
if fieldValue != nil {
// 递归验证字段值
if !validateFieldValue(fieldValue, fieldSchema) {
return false
}
}
}
}
}
}
// 检查 enum
if enumRaw, exists := schema["enum"]; exists {
if enum, ok := enumRaw.([]interface{}); ok {
found := false
for _, val := range enum {
if compareEq(doc, val) {
found = true
break
}
}
if !found {
return false
}
}
}
// 检查 allOf
if allOfRaw, exists := schema["allOf"]; exists {
if allOf, ok := allOfRaw.([]interface{}); ok {
for _, subSchemaRaw := range allOf {
if subSchema, ok := subSchemaRaw.(map[string]interface{}); ok {
if !validateJSONSchema(doc, subSchema) {
return false
}
}
}
}
}
// 检查 anyOf
if anyOfRaw, exists := schema["anyOf"]; exists {
if anyOf, ok := anyOfRaw.([]interface{}); ok {
matched := false
for _, subSchemaRaw := range anyOf {
if subSchema, ok := subSchemaRaw.(map[string]interface{}); ok {
if validateJSONSchema(doc, subSchema) {
matched = true
break
}
}
}
if !matched {
return false
}
}
}
// 检查 oneOf
if oneOfRaw, exists := schema["oneOf"]; exists {
if oneOf, ok := oneOfRaw.([]interface{}); ok {
matchCount := 0
for _, subSchemaRaw := range oneOf {
if subSchema, ok := subSchemaRaw.(map[string]interface{}); ok {
if validateJSONSchema(doc, subSchema) {
matchCount++
}
}
}
if matchCount != 1 {
return false
}
}
}
// 检查 not
if notRaw, exists := schema["not"]; exists {
if notSchema, ok := notRaw.(map[string]interface{}); ok {
if validateJSONSchema(doc, notSchema) {
return false // not 要求不匹配
}
}
}
return true
}
// validateFieldValue reports whether value conforms to schema.
//
// Supported keywords:
//   - any value:  bsonType, enum, allOf, anyOf, oneOf, not
//   - numbers:    minimum, maximum
//   - strings:    minLength, maxLength (counted in Unicode code points), pattern
//   - arrays:     items (applied to every element), minItems, maxItems
//   - objects:    required, properties (validated recursively)
//
// A keyword that does not apply to the kind of value at hand is ignored, as
// is a keyword whose operand has an unexpected Go type. Unknown keywords are
// ignored too.
//
// NOTE(review): required and properties treat a field whose value is nil the
// same as a missing field.
func validateFieldValue(value interface{}, schema map[string]interface{}) bool {
	if bsonType, exists := schema["bsonType"]; exists && !validateBsonType(value, bsonType) {
		return false
	}

	// enum: the value must equal at least one of the listed candidates.
	if enum, ok := schema["enum"].([]interface{}); ok {
		found := false
		for _, candidate := range enum {
			if compareEq(value, candidate) {
				found = true
				break
			}
		}
		if !found {
			return false
		}
	}

	// Numeric bounds only apply when the value itself is numeric.
	if num, isNum := toNumber(value); isNum {
		if minimum, exists := schema["minimum"]; exists && num < toFloat64(minimum) {
			return false
		}
		if maximum, exists := schema["maximum"]; exists && num > toFloat64(maximum) {
			return false
		}
	}

	// String keywords only apply to strings.
	if str, isStr := value.(string); isStr {
		// JSON Schema measures string length in characters, not bytes, so
		// count code points; len(str) would over-count non-ASCII text.
		length := len([]rune(str))
		if minLength, exists := schema["minLength"]; exists && length < int(toFloat64(minLength)) {
			return false
		}
		if maxLength, exists := schema["maxLength"]; exists && length > int(toFloat64(maxLength)) {
			return false
		}
		if pattern, ok := schema["pattern"].(string); ok && !compareRegex(str, pattern) {
			return false
		}
	}

	// Array keywords only apply to arrays.
	if arr, isArr := value.([]interface{}); isArr {
		if itemSchema, ok := schema["items"].(map[string]interface{}); ok {
			// Every element is validated, whatever its kind. Restricting the
			// check to object elements would let e.g. a number slip through
			// an items schema that demands strings.
			for _, item := range arr {
				if !validateFieldValue(item, itemSchema) {
					return false
				}
			}
		}
		if minItems, exists := schema["minItems"]; exists && len(arr) < int(toFloat64(minItems)) {
			return false
		}
		if maxItems, exists := schema["maxItems"]; exists && len(arr) > int(toFloat64(maxItems)) {
			return false
		}
	}

	// Object keywords only apply to objects.
	if obj, isObj := value.(map[string]interface{}); isObj {
		if required, ok := schema["required"].([]interface{}); ok {
			for _, entry := range required {
				// Non-string entries in the required list are skipped.
				if name, isName := entry.(string); isName && obj[name] == nil {
					return false
				}
			}
		}
		if properties, ok := schema["properties"].(map[string]interface{}); ok {
			for name, rawSchema := range properties {
				propSchema, isSchema := rawSchema.(map[string]interface{})
				if !isSchema {
					continue
				}
				// Absent (or nil) properties are not validated; presence is
				// the job of "required".
				propValue := obj[name]
				if propValue == nil {
					continue
				}
				if !validateFieldValue(propValue, propSchema) {
					return false
				}
			}
		}
	}

	// allOf: every sub-schema must accept the value.
	if allOf, ok := schema["allOf"].([]interface{}); ok {
		for _, raw := range allOf {
			if sub, isSchema := raw.(map[string]interface{}); isSchema && !validateFieldValue(value, sub) {
				return false
			}
		}
	}

	// anyOf: at least one sub-schema must accept the value.
	if anyOf, ok := schema["anyOf"].([]interface{}); ok {
		matched := false
		for _, raw := range anyOf {
			if sub, isSchema := raw.(map[string]interface{}); isSchema && validateFieldValue(value, sub) {
				matched = true
				break
			}
		}
		if !matched {
			return false
		}
	}

	// oneOf: exactly one sub-schema must accept the value.
	if oneOf, ok := schema["oneOf"].([]interface{}); ok {
		matches := 0
		for _, raw := range oneOf {
			if sub, isSchema := raw.(map[string]interface{}); isSchema && validateFieldValue(value, sub) {
				matches++
			}
		}
		if matches != 1 {
			return false
		}
	}

	// not: the sub-schema must reject the value.
	if notSchema, ok := schema["not"].(map[string]interface{}); ok && validateFieldValue(value, notSchema) {
		return false
	}

	return true
}
// validateBsonType reports whether value has the BSON type described by
// bsonType.
//
// bsonType may be:
//   - a type name: "object", "array", "string", "int", "long", "double",
//     "decimal", "number", "bool", "null", "date" or "objectId";
//   - a list of such names ([]interface{} or []string), in which case the
//     value must match at least one of them.
//
// Type mapping: "int"/"long" accept Go signed integers, "double"/"decimal"
// accept Go floats, "number" accepts any Go integer or float, "date" accepts
// time.Time and "objectId" accepts strings.
//
// Validation is deliberately permissive about the schema itself: an
// unrecognised type name, an empty list, or a bsonType of any other Go type
// accepts every value.
func validateBsonType(value interface{}, bsonType interface{}) bool {
	switch spec := bsonType.(type) {
	case []interface{}:
		if len(spec) == 0 {
			return true
		}
		for _, alternative := range spec {
			if validateBsonType(value, alternative) {
				return true
			}
		}
		return false
	case []string:
		if len(spec) == 0 {
			return true
		}
		for _, alternative := range spec {
			if validateBsonType(value, alternative) {
				return true
			}
		}
		return false
	case string:
		switch spec {
		case "object":
			_, ok := value.(map[string]interface{})
			return ok
		case "array":
			_, ok := value.([]interface{})
			return ok
		case "string":
			_, ok := value.(string)
			return ok
		case "int", "long":
			switch value.(type) {
			case int, int8, int16, int32, int64:
				return true
			}
			return false
		case "double", "decimal":
			switch value.(type) {
			case float32, float64:
				return true
			}
			return false
		case "number":
			// Alias covering every numeric type.
			switch value.(type) {
			case int, int8, int16, int32, int64,
				uint, uint8, uint16, uint32, uint64,
				float32, float64:
				return true
			}
			return false
		case "bool":
			_, ok := value.(bool)
			return ok
		case "null":
			return value == nil
		case "date":
			_, ok := value.(time.Time)
			return ok
		case "objectId":
			// Object IDs are represented as strings here.
			_, ok := value.(string)
			return ok
		}
		// Unrecognised type name: do not reject.
		return true
	}
	// bsonType is neither a name nor a list of names: do not reject.
	return true
}
// getNumericValue converts any built-in integer or floating-point value to
// float64. Values of any other type (including nil) yield 0.
func getNumericValue(value interface{}) float64 {
	var result float64 // stays 0 for non-numeric input
	switch n := value.(type) {
	case float64:
		result = n
	case float32:
		result = float64(n)
	case int:
		result = float64(n)
	case int64:
		result = float64(n)
	case int32:
		result = float64(n)
	case int16:
		result = float64(n)
	case int8:
		result = float64(n)
	case uint:
		result = float64(n)
	case uint64:
		result = float64(n)
	case uint32:
		result = float64(n)
	case uint16:
		result = float64(n)
	case uint8:
		result = float64(n)
	}
	return result
}
// toArray returns value as a []interface{} together with true when value has
// exactly that dynamic type; otherwise it returns nil and false.
func toArray(value interface{}) ([]interface{}, bool) {
	// On a failed assertion arr is the zero value (nil), which is exactly
	// what callers expect alongside false.
	arr, ok := value.([]interface{})
	return arr, ok
}
// toNumber converts value to float64. The boolean result is true only when
// value is one of Go's built-in integer or floating-point types; for anything
// else it returns 0 and false.
func toNumber(value interface{}) (float64, bool) {
	switch value.(type) {
	case int, int8, int16, int32, int64,
		uint, uint8, uint16, uint32, uint64,
		float32, float64:
		return getNumericValue(value), true
	}
	return 0, false
}
// handleAnd evaluates a $and clause: doc must match every sub-filter.
//
// condition must be a []interface{}; any other type makes the clause fail.
// Only elements whose dynamic type is map[string]interface{} are evaluated;
// other elements are skipped. An empty list therefore matches.
func handleAnd(doc map[string]interface{}, condition interface{}) bool {
	clauses, isList := condition.([]interface{})
	if !isList {
		return false
	}
	for i := range clauses {
		sub, isMap := clauses[i].(map[string]interface{})
		if !isMap {
			continue
		}
		if !MatchFilter(doc, sub) {
			return false
		}
	}
	return true
}
// handleOr evaluates a $or clause: doc must match at least one sub-filter.
//
// condition must be a []interface{}; any other type makes the clause fail.
// Only elements whose dynamic type is map[string]interface{} are evaluated;
// other elements are skipped. An empty list therefore never matches.
func handleOr(doc map[string]interface{}, condition interface{}) bool {
	clauses, isList := condition.([]interface{})
	if !isList {
		return false
	}
	for i := range clauses {
		sub, isMap := clauses[i].(map[string]interface{})
		if !isMap {
			continue
		}
		if MatchFilter(doc, sub) {
			return true
		}
	}
	return false
}
// handleNor evaluates a $nor clause: doc must match none of the sub-filters.
//
// condition must be a []interface{}; any other type makes the clause fail.
// Only elements whose dynamic type is map[string]interface{} are evaluated;
// other elements are skipped. An empty list therefore matches.
func handleNor(doc map[string]interface{}, condition interface{}) bool {
	clauses, isList := condition.([]interface{})
	if !isList {
		return false
	}
	for i := range clauses {
		sub, isMap := clauses[i].(map[string]interface{})
		if !isMap {
			continue
		}
		if MatchFilter(doc, sub) {
			return false
		}
	}
	return true
}
// handleNot evaluates a top-level $not clause: doc must NOT match the given
// sub-filter.
//
// A condition whose dynamic type is not map[string]interface{} is ignored
// and the clause passes.
func handleNot(doc map[string]interface{}, condition interface{}) bool {
	sub, isMap := condition.(map[string]interface{})
	if !isMap {
		return true
	}
	return !MatchFilter(doc, sub)
}
// matchField checks a single field condition against doc.
//
// key is resolved with getNestedValue, so dotted paths such as "a.b.c" are
// supported. A non-nil condition of type types.Filter or
// map[string]interface{} is always interpreted as an operator document and
// handed to evaluateOperators; every other condition is compared for
// equality with compareEq.
func matchField(doc map[string]interface{}, key string, condition interface{}) bool {
	fieldValue := getNestedValue(doc, key)

	// Accept both representations of an operator document.
	ops, isMap := condition.(map[string]interface{})
	if !isMap {
		if asFilter, isFilter := condition.(types.Filter); isFilter {
			ops = asFilter
		}
	}

	if ops == nil {
		// Not an operator document: plain equality.
		return compareEq(fieldValue, condition)
	}
	return evaluateOperators(fieldValue, ops)
}
// getNestedValue resolves a dotted path such as "a.b.c" inside doc and
// returns the value found there.
//
// It returns nil when a segment is missing or when an intermediate value is
// not a map[string]interface{}.
func getNestedValue(doc map[string]interface{}, key string) interface{} {
	var node interface{} = doc
	remaining := key
	for {
		// Peel the next segment off the front of the path.
		segment := remaining
		dot := strings.Index(remaining, ".")
		if dot >= 0 {
			segment = remaining[:dot]
		}

		obj, isObj := node.(map[string]interface{})
		if !isObj {
			return nil
		}
		node = obj[segment]

		if dot < 0 {
			// That was the last segment.
			return node
		}
		remaining = remaining[dot+1:]
	}
}
// evaluateOperators reports whether value satisfies every operator in the
// given operator document.
//
// Supported operators: $eq, $ne, $gt, $gte, $lt, $lte, $in, $nin, $regex,
// $exists, $type, $all, $elemMatch, $size, $mod, $not, $bitsAllClear,
// $bitsAllSet, $bitsAnyClear and $bitsAnySet. All listed operators must hold.
//
// Notes:
//   - $exists treats a nil value as "field absent" and only acts when its
//     operand is a bool; other operand types are ignored.
//   - $not negates a nested operator document, e.g. {"$not": {"$gt": 5}}.
//     It only acts when its operand is a map[string]interface{}; other
//     operand types are ignored.
//   - Unknown operators are skipped and do not affect the result.
func evaluateOperators(value interface{}, operators map[string]interface{}) bool {
	for op, operand := range operators {
		// Operators that are skipped leave matched at true.
		matched := true
		switch op {
		case "$eq":
			matched = compareEq(value, operand)
		case "$ne":
			matched = !compareEq(value, operand)
		case "$gt":
			matched = compareGt(value, operand)
		case "$gte":
			matched = compareGte(value, operand)
		case "$lt":
			matched = compareLt(value, operand)
		case "$lte":
			matched = compareLte(value, operand)
		case "$in":
			matched = compareIn(value, operand)
		case "$nin":
			matched = !compareIn(value, operand)
		case "$regex":
			matched = compareRegex(value, operand)
		case "$exists":
			if want, isBool := operand.(bool); isBool {
				matched = (value != nil) == want
			}
		case "$type":
			matched = compareType(value, operand)
		case "$all":
			matched = compareAll(value, operand)
		case "$elemMatch":
			matched = compareElemMatch(value, operand)
		case "$size":
			matched = compareSize(value, operand)
		case "$mod":
			matched = compareMod(value, operand)
		case "$not":
			// Without this case a field-level $not would fall through to
			// the unknown-operator branch and match every document.
			if inner, isMap := operand.(map[string]interface{}); isMap {
				matched = !evaluateOperators(value, inner)
			}
		case "$bitsAllClear":
			matched = compareBitsAllClear(value, operand)
		case "$bitsAllSet":
			matched = compareBitsAllSet(value, operand)
		case "$bitsAnyClear":
			matched = compareBitsAnyClear(value, operand)
		case "$bitsAnySet":
			matched = compareBitsAnySet(value, operand)
		default:
			// Unknown operator: skip.
		}
		if !matched {
			return false
		}
	}
	return true
}