diff --git a/internal/parser/ast.go b/internal/parser/ast.go index 78ffb07..afe6bc6 100644 --- a/internal/parser/ast.go +++ b/internal/parser/ast.go @@ -153,3 +153,12 @@ type BinaryExpression struct { func (b *BinaryExpression) Pos() Position { return b.Position } func (b *BinaryExpression) isValue() {} + +type UnaryExpression struct { + Position Position + Operator Token + Right Value +} + +func (u *UnaryExpression) Pos() Position { return u.Position } +func (u *UnaryExpression) isValue() {} diff --git a/internal/parser/lexer.go b/internal/parser/lexer.go index 2bbead7..cbfcfa7 100644 --- a/internal/parser/lexer.go +++ b/internal/parser/lexer.go @@ -147,18 +147,12 @@ func (l *Lexer) NextToken() Token { case ']': return l.emit(TokenRBracket) case '+': - if unicode.IsSpace(l.peek()) { + if unicode.IsSpace(l.peek()) || unicode.IsDigit(l.peek()) { return l.emit(TokenPlus) } return l.lexObjectIdentifier() case '-': - if unicode.IsDigit(l.peek()) { - return l.lexNumber() - } - if unicode.IsSpace(l.peek()) { - return l.emit(TokenMinus) - } - return l.lexIdentifier() + return l.emit(TokenMinus) case '*': return l.emit(TokenStar) case '/': @@ -242,13 +236,28 @@ func (l *Lexer) lexString() Token { } func (l *Lexer) lexNumber() Token { - for { - r := l.next() - if unicode.IsDigit(r) || unicode.IsLetter(r) || r == '.' || r == '-' || r == '+' { - continue + // Consume initial digits (already started) + l.lexDigits() + + if l.peek() == '.' { + l.next() + l.lexDigits() + } + + if r := l.peek(); r == 'e' || r == 'E' { + l.next() + if p := l.peek(); p == '+' || p == '-' { + l.next() } - l.backup() - return l.emit(TokenNumber) + l.lexDigits() + } + + return l.emit(TokenNumber) +} + +func (l *Lexer) lexDigits() { + for unicode.IsDigit(l.peek()) { + l.next() } } @@ -318,7 +327,7 @@ func (l *Lexer) lexHashIdentifier() Token { func (l *Lexer) lexVariableReference() Token { for { r := l.next() - if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '-' { + if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' { continue } l.backup() diff --git a/internal/parser/parser.go b/internal/parser/parser.go index 6d1f73b..6bcee1f 100644 --- a/internal/parser/parser.go +++ b/internal/parser/parser.go @@ -299,8 +299,27 @@ func (p *Parser) parseAtom() (Value, bool) { return &ReferenceValue{Position: tok.Position, Value: tok.Value}, true case TokenVariableReference: return &VariableReferenceValue{Position: tok.Position, Name: tok.Value}, true + case TokenMinus: + val, ok := p.parseAtom() + if !ok { + return nil, false + } + return &UnaryExpression{Position: tok.Position, Operator: tok, Right: val}, true case TokenObjectIdentifier: return &VariableReferenceValue{Position: tok.Position, Name: tok.Value}, true + case TokenSymbol: + if tok.Value == "(" { + val, ok := p.parseExpression(0) + if !ok { + return nil, false + } + if next := p.next(); next.Type != TokenSymbol || next.Value != ")" { + p.addError(next.Position, "expected )") + return nil, false + } + return val, true + } + fallthrough case TokenLBrace: arr := &ArrayValue{Position: tok.Position} for { diff --git a/internal/validator/validator.go b/internal/validator/validator.go index 4e3f74e..7be86ff 100644 --- a/internal/validator/validator.go +++ b/internal/validator/validator.go @@ -236,6 +236,108 @@ func (v *Validator) valueToInterface(val parser.Value, ctx *index.ProjectNode) i arr = append(arr, v.valueToInterface(e, ctx)) } return arr + case *parser.BinaryExpression: + left := v.valueToInterface(t.Left, ctx) + right := v.valueToInterface(t.Right, ctx) + return v.evaluateBinary(left, t.Operator.Type, right) + case *parser.UnaryExpression: + val := v.valueToInterface(t.Right, ctx) + return v.evaluateUnary(t.Operator.Type, val) + } + return nil +} + +func (v *Validator) evaluateBinary(left interface{}, op parser.TokenType, right interface{}) interface{} { + if left == nil || right == nil { + return nil + } + + if op == parser.TokenConcat { + return fmt.Sprintf("%v%v", left, right) + } + + toInt := func(val interface{}) (int64, bool) { + switch v := val.(type) { + case int64: + return v, true + case int: + return int64(v), true + } + return 0, false + } + + toFloat := func(val interface{}) (float64, bool) { + switch v := val.(type) { + case float64: + return v, true + case int64: + return float64(v), true + case int: + return float64(v), true + } + return 0, false + } + + if l, ok := toInt(left); ok { + if r, ok := toInt(right); ok { + switch op { + case parser.TokenPlus: + return l + r + case parser.TokenMinus: + return l - r + case parser.TokenStar: + return l * r + case parser.TokenSlash: + if r != 0 { + return l / r + } + case parser.TokenPercent: + if r != 0 { + return l % r + } + } + } + } + + if l, ok := toFloat(left); ok { + if r, ok := toFloat(right); ok { + switch op { + case parser.TokenPlus: + return l + r + case parser.TokenMinus: + return l - r + case parser.TokenStar: + return l * r + case parser.TokenSlash: + if r != 0 { + return l / r + } + } + } + } + + return nil +} + +func (v *Validator) evaluateUnary(op parser.TokenType, val interface{}) interface{} { + if val == nil { + return nil + } + + switch op { + case parser.TokenMinus: + switch v := val.(type) { + case int64: + return -v + case float64: + return -v + } + case parser.TokenSymbol: // ! is Symbol? + // Parser uses TokenSymbol for ! ? + // Lexer: '!' -> Symbol. + if b, ok := val.(bool); ok { + return !b + } } return nil } diff --git a/test/expression_parsing_test.go b/test/expression_parsing_test.go new file mode 100644 index 0000000..1643141 --- /dev/null +++ b/test/expression_parsing_test.go @@ -0,0 +1,60 @@ +package integration + +import ( + "os" + "strings" + "testing" + + "github.com/marte-community/marte-dev-tools/internal/builder" +) + +func TestExpressionParsing(t *testing.T) { + content := ` +#var A: int = 10 +#var B: int = 2 + ++Obj = { + // 1. Multiple variables + Expr1 = @A + @B + @A + + // 2. Brackets + Expr2 = (@A + 2) * @B + + // 3. No space operator (variable name strictness) + Expr3 = @A-2 +} +` + f, _ := os.CreateTemp("", "expr_test.marte") + f.WriteString(content) + f.Close() + defer os.Remove(f.Name()) + + b := builder.NewBuilder([]string{f.Name()}, nil) + + outF, _ := os.CreateTemp("", "out.marte") + defer os.Remove(outF.Name()) + + err := b.Build(outF) + if err != nil { + t.Fatalf("Build failed: %v", err) + } + outF.Close() + + outContent, _ := os.ReadFile(outF.Name()) + outStr := string(outContent) + + // Expr1: 10 + 2 + 10 = 22 + if !strings.Contains(outStr, "Expr1 = 22") { + t.Errorf("Expr1 failed. Got:\n%s", outStr) + } + + // Expr2: (10 + 2) * 2 = 24 + if !strings.Contains(outStr, "Expr2 = 24") { + t.Errorf("Expr2 failed. Got:\n%s", outStr) + } + + // Expr3: 10 - 2 = 8 + if !strings.Contains(outStr, "Expr3 = 8") { + t.Errorf("Expr3 failed. Got:\n%s", outStr) + } +} diff --git a/test/expression_whitespace_test.go b/test/expression_whitespace_test.go new file mode 100644 index 0000000..a338264 --- /dev/null +++ b/test/expression_whitespace_test.go @@ -0,0 +1,39 @@ +package integration + +import ( + "os" + "strings" + "testing" + + "github.com/marte-community/marte-dev-tools/internal/builder" +) + +func TestExpressionWhitespace(t *testing.T) { + content := ` ++Obj = { + NoSpace = 2+2 + WithSpace = 2 + 2 +} +` + f, _ := os.CreateTemp("", "expr_ws.marte") + f.WriteString(content) + f.Close() + defer os.Remove(f.Name()) + + b := builder.NewBuilder([]string{f.Name()}, nil) + + outF, _ := os.CreateTemp("", "out.marte") + defer os.Remove(outF.Name()) + b.Build(outF) + outF.Close() + + outContent, _ := os.ReadFile(outF.Name()) + outStr := string(outContent) + + if !strings.Contains(outStr, "NoSpace = 4") { + t.Errorf("NoSpace failed. Got:\n%s", outStr) + } + if !strings.Contains(outStr, "WithSpace = 4") { + t.Errorf("WithSpace failed. Got:\n%s", outStr) + } +}