diff --git a/internal/index/index.go b/internal/index/index.go index 3b93b5f..1a630b2 100644 --- a/internal/index/index.go +++ b/internal/index/index.go @@ -8,13 +8,18 @@ import ( "github.com/marte-community/marte-dev-tools/internal/parser" ) +type VariableInfo struct { + Def *parser.VariableDefinition + File string +} + type ProjectTree struct { Root *ProjectNode References []Reference IsolatedFiles map[string]*ProjectNode GlobalPragmas map[string][]string NodeMap map[string][]*ProjectNode - Variables map[string]*parser.VariableDefinition + Variables map[string]VariableInfo } func (pt *ProjectTree) ScanDirectory(rootPath string) error { @@ -74,7 +79,7 @@ func NewProjectTree() *ProjectTree { }, IsolatedFiles: make(map[string]*ProjectNode), GlobalPragmas: make(map[string][]string), - Variables: make(map[string]*parser.VariableDefinition), + Variables: make(map[string]VariableInfo), } } @@ -224,7 +229,7 @@ func (pt *ProjectTree) populateNode(node *ProjectNode, file string, config *pars pt.indexValue(file, d.Value) case *parser.VariableDefinition: fileFragment.Definitions = append(fileFragment.Definitions, d) - pt.Variables[d.Name] = d + pt.Variables[d.Name] = VariableInfo{Def: d, File: file} case *parser.ObjectNode: fileFragment.Definitions = append(fileFragment.Definitions, d) norm := NormalizeName(d.Name) @@ -282,7 +287,7 @@ func (pt *ProjectTree) addObjectFragment(node *ProjectNode, file string, obj *pa pt.extractFieldMetadata(node, d) case *parser.VariableDefinition: frag.Definitions = append(frag.Definitions, d) - pt.Variables[d.Name] = d + pt.Variables[d.Name] = VariableInfo{Def: d, File: file} case *parser.ObjectNode: frag.Definitions = append(frag.Definitions, d) norm := NormalizeName(d.Name) @@ -418,7 +423,7 @@ func (pt *ProjectTree) ResolveReferences() { ref := &pt.References[i] if v, ok := pt.Variables[ref.Name]; ok { - ref.TargetVariable = v + ref.TargetVariable = v.Def continue } diff --git a/internal/schema/marte.cue b/internal/schema/marte.cue index 6ad8caa..5cb8de6 100644 --- a/internal/schema/marte.cue +++ b/internal/schema/marte.cue @@ -413,7 +413,7 @@ package schema OPCUA: {...} SysLogger: {...} GAMDataSource: { - #meta: multithreaded: bool | *false + #meta: multithreaded: false #meta: direction: "INOUT" #meta: type: "datasource" ... @@ -421,7 +421,7 @@ package schema } #Meta: { - direction?: "IN" | "OUT" | "INOUT" + direction?: "IN" | "OUT" | "INOUT" multithreaded?: bool ... } @@ -430,7 +430,7 @@ package schema // It must have a Class field. // Based on Class, it validates against #Classes. #Object: { - Class: string + Class: string "#meta"?: #Meta // Allow any other field by default (extensibility), // unless #Classes definition is closed. diff --git a/internal/validator/validator.go b/internal/validator/validator.go index 2d20b80..f29ed3a 100644 --- a/internal/validator/validator.go +++ b/internal/validator/validator.go @@ -56,6 +56,7 @@ func (v *Validator) ValidateProject() { v.CheckUnused() v.CheckDataSourceThreading() v.CheckINOUTOrdering() + v.CheckVariables() } func (v *Validator) validateNode(node *index.ProjectNode) { @@ -936,53 +937,111 @@ func (v *Validator) CheckINOUTOrdering() { } for _, thread := range threads { - producedSignals := make(map[*index.ProjectNode]bool) + producedSignals := make(map[*index.ProjectNode]map[string][]*index.ProjectNode) + consumedSignals := make(map[*index.ProjectNode]map[string]bool) + gams := v.getThreadGAMs(thread) for _, gam := range gams { - v.processGAMSignalsForOrdering(gam, "InputSignals", producedSignals, true, thread, state) - v.processGAMSignalsForOrdering(gam, "OutputSignals", producedSignals, false, thread, state) + v.processGAMSignalsForOrdering(gam, "InputSignals", producedSignals, consumedSignals, true, thread, state) + v.processGAMSignalsForOrdering(gam, "OutputSignals", producedSignals, consumedSignals, false, thread, state) + } + + // Check for produced but not consumed + for ds, signals := range producedSignals { + for sigName, producers := range signals { + consumed := false + if cSet, ok := consumedSignals[ds]; ok { + if cSet[sigName] { + consumed = true + } + } + if !consumed { + for _, prod := range producers { + v.Diagnostics = append(v.Diagnostics, Diagnostic{ + Level: LevelWarning, + Message: fmt.Sprintf("INOUT Signal '%s' (DS '%s') is produced in thread '%s' but never consumed in the same thread.", sigName, ds.RealName, thread.RealName), + Position: v.getNodePosition(prod), + File: v.getNodeFile(prod), + }) + } + } + } } } } } -func (v *Validator) processGAMSignalsForOrdering(gam *index.ProjectNode, containerName string, produced map[*index.ProjectNode]bool, isInput bool, thread, state *index.ProjectNode) { +func (v *Validator) processGAMSignalsForOrdering(gam *index.ProjectNode, containerName string, produced map[*index.ProjectNode]map[string][]*index.ProjectNode, consumed map[*index.ProjectNode]map[string]bool, isInput bool, thread, state *index.ProjectNode) { container := gam.Children[containerName] if container == nil { return } for _, sig := range container.Children { - if sig.Target == nil { + fields := v.getFields(sig) + var dsNode *index.ProjectNode + var sigName string + + if sig.Target != nil { + if sig.Target.Parent != nil && sig.Target.Parent.Parent != nil { + dsNode = sig.Target.Parent.Parent + sigName = sig.Target.RealName + } + } + + if dsNode == nil { + if dsFields, ok := fields["DataSource"]; ok && len(dsFields) > 0 { + dsName := v.getFieldValue(dsFields[0]) + dsNode = v.resolveReference(dsName, v.getNodeFile(sig), isDataSource) + } + if aliasFields, ok := fields["Alias"]; ok && len(aliasFields) > 0 { + sigName = v.getFieldValue(aliasFields[0]) + } else { + sigName = sig.RealName + } + } + + if dsNode == nil || sigName == "" { continue } - targetSig := sig.Target - if targetSig.Parent == nil || targetSig.Parent.Parent == nil { - continue - } - ds := targetSig.Parent.Parent + sigName = index.NormalizeName(sigName) - if v.isMultithreaded(ds) { + if v.isMultithreaded(dsNode) { continue } - dir := v.getDataSourceDirection(ds) + dir := v.getDataSourceDirection(dsNode) if dir != "INOUT" { continue } if isInput { - if !produced[targetSig] { + isProduced := false + if set, ok := produced[dsNode]; ok { + if len(set[sigName]) > 0 { + isProduced = true + } + } + + if !isProduced { v.Diagnostics = append(v.Diagnostics, Diagnostic{ Level: LevelError, - Message: fmt.Sprintf("INOUT Signal '%s' (DS '%s') is consumed by GAM '%s' in thread '%s' (State '%s') before being produced by any previous GAM.", targetSig.RealName, ds.RealName, gam.RealName, thread.RealName, state.RealName), + Message: fmt.Sprintf("INOUT Signal '%s' (DS '%s') is consumed by GAM '%s' in thread '%s' (State '%s') before being produced by any previous GAM.", sigName, dsNode.RealName, gam.RealName, thread.RealName, state.RealName), Position: v.getNodePosition(sig), File: v.getNodeFile(sig), }) } + + if consumed[dsNode] == nil { + consumed[dsNode] = make(map[string]bool) + } + consumed[dsNode][sigName] = true } else { - produced[targetSig] = true + if produced[dsNode] == nil { + produced[dsNode] = make(map[string][]*index.ProjectNode) + } + produced[dsNode][sigName] = append(produced[dsNode][sigName], sig) } } } @@ -1003,3 +1062,42 @@ func (v *Validator) getDataSourceDirection(ds *index.ProjectNode) string { } return "" } + +func (v *Validator) CheckVariables() { + if v.Schema == nil { + return + } + ctx := v.Schema.Context + + for _, info := range v.Tree.Variables { + def := info.Def + + // Compile Type + typeVal := ctx.CompileString(def.TypeExpr) + if typeVal.Err() != nil { + v.Diagnostics = append(v.Diagnostics, Diagnostic{ + Level: LevelError, + Message: fmt.Sprintf("Invalid type expression for variable '%s': %v", def.Name, typeVal.Err()), + Position: def.Position, + File: info.File, + }) + continue + } + + if def.DefaultValue != nil { + valInterface := v.valueToInterface(def.DefaultValue) + valVal := ctx.Encode(valInterface) + + // Unify + res := typeVal.Unify(valVal) + if err := res.Validate(cue.Concrete(true)); err != nil { + v.Diagnostics = append(v.Diagnostics, Diagnostic{ + Level: LevelError, + Message: fmt.Sprintf("Variable '%s' value mismatch: %v", def.Name, err), + Position: def.Position, + File: info.File, + }) + } + } + } +} diff --git a/test/formatter_variables_test.go b/test/formatter_variables_test.go new file mode 100644 index 0000000..76d7d5d --- /dev/null +++ b/test/formatter_variables_test.go @@ -0,0 +1,45 @@ +package integration + +import ( + "bytes" + "strings" + "testing" + + "github.com/marte-community/marte-dev-tools/internal/formatter" + "github.com/marte-community/marte-dev-tools/internal/parser" +) + +func TestFormatterVariables(t *testing.T) { + content := ` +#var MyInt: int = 10 +#var MyStr: string | "A" = "default" + ++Obj = { + Field1 = $MyInt + Field2 = $MyStr +} +` + p := parser.NewParser(content) + cfg, err := p.Parse() + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + + var buf bytes.Buffer + formatter.Format(cfg, &buf) + + output := buf.String() + + // Parser reconstructs type expression with spaces + if !strings.Contains(output, "#var MyInt: int = 10") { + t.Errorf("Variable MyInt formatted incorrectly. Got:\n%s", output) + } + // Note: parser adds space after each token in TypeExpr + // string | "A" -> "string | \"A\"" + if !strings.Contains(output, "#var MyStr: string | \"A\" = \"default\"") { + t.Errorf("Variable MyStr formatted incorrectly. Got:\n%s", output) + } + if !strings.Contains(output, "Field1 = $MyInt") { + t.Errorf("Variable reference $MyInt formatted incorrectly. Got:\n%s", output) + } +} diff --git a/test/lsp_inout_test.go b/test/lsp_inout_test.go new file mode 100644 index 0000000..90f489c --- /dev/null +++ b/test/lsp_inout_test.go @@ -0,0 +1,73 @@ +package integration + +import ( + "bytes" + "strings" + "testing" + + "github.com/marte-community/marte-dev-tools/internal/index" + "github.com/marte-community/marte-dev-tools/internal/lsp" + "github.com/marte-community/marte-dev-tools/internal/schema" +) + +func TestLSPINOUTOrdering(t *testing.T) { + lsp.Tree = index.NewProjectTree() + lsp.Documents = make(map[string]string) + // Mock schema if necessary, but we rely on internal schema + lsp.GlobalSchema = schema.LoadFullSchema(".") + + var buf bytes.Buffer + lsp.Output = &buf + + content := ` ++App = { + Class = RealTimeApplication + +Data = { + Class = ReferenceContainer + +DDB = { + Class = GAMDataSource + } + } + +Functions = { + Class = ReferenceContainer + +A = { + Class = IOGAM + InputSignals = { + A = { + DataSource = DDB + Type = uint32 + } + } + OutputSignals = { + B = { + DataSource = DDB + Type = uint32 + } + } + } + } + +States = { + Class = ReferenceContainer + +State = { + Class =RealTimeState + Threads = { + +Th1 = { + Class = RealTimeThread + Functions = {A} + } + } + } + } +} +` + uri := "file://app.marte" + lsp.HandleDidOpen(lsp.DidOpenTextDocumentParams{ + TextDocument: lsp.TextDocumentItem{URI: uri, Text: content}, + }) + + output := buf.String() + if !strings.Contains(output, "INOUT Signal 'A'") { + t.Error("LSP did not report INOUT ordering error") + t.Log(output) + } +} diff --git a/test/lsp_inout_warning_test.go b/test/lsp_inout_warning_test.go new file mode 100644 index 0000000..1a01418 --- /dev/null +++ b/test/lsp_inout_warning_test.go @@ -0,0 +1,66 @@ +package integration + +import ( + "bytes" + "strings" + "testing" + + "github.com/marte-community/marte-dev-tools/internal/index" + "github.com/marte-community/marte-dev-tools/internal/lsp" + "github.com/marte-community/marte-dev-tools/internal/schema" +) + +func TestLSPINOUTWarning(t *testing.T) { + lsp.Tree = index.NewProjectTree() + lsp.Documents = make(map[string]string) + lsp.GlobalSchema = schema.LoadFullSchema(".") + + var buf bytes.Buffer + lsp.Output = &buf + + content := ` ++App = { + Class = RealTimeApplication + +Data = { + Class = ReferenceContainer + +DDB = { + Class = GAMDataSource + } + } + +Functions = { + Class = ReferenceContainer + +Producer = { + Class = IOGAM + OutputSignals = { + ProducedSig = { + DataSource = DDB + Type = uint32 + } + } + } + } + +States = { + Class = ReferenceContainer + +State = { + Class =RealTimeState + Threads = { + +Th1 = { + Class = RealTimeThread + Functions = {Producer} + } + } + } + } +} +` + uri := "file://warning.marte" + lsp.HandleDidOpen(lsp.DidOpenTextDocumentParams{ + TextDocument: lsp.TextDocumentItem{URI: uri, Text: content}, + }) + + output := buf.String() + if !strings.Contains(output, "produced in thread '+Th1' but never consumed") { + t.Error("LSP did not report INOUT usage warning") + t.Log(output) + } +} diff --git a/test/validator_inout_ordering_test.go b/test/validator_inout_ordering_test.go new file mode 100644 index 0000000..c7ed2c4 --- /dev/null +++ b/test/validator_inout_ordering_test.go @@ -0,0 +1,93 @@ +package integration + +import ( + "strings" + "testing" + + "github.com/marte-community/marte-dev-tools/internal/index" + "github.com/marte-community/marte-dev-tools/internal/parser" + "github.com/marte-community/marte-dev-tools/internal/validator" +) + +func TestINOUTOrdering(t *testing.T) { + content := ` ++Data = { + Class = ReferenceContainer + +MyDS = { + Class = GAMDataSource + #meta = { multithreaded = false } // Explicitly false + Signals = { Sig1 = { Type = uint32 } } + } +} ++GAM_Consumer = { + Class = IOGAM + InputSignals = { + Sig1 = { DataSource = MyDS Type = uint32 } + } +} ++GAM_Producer = { + Class = IOGAM + OutputSignals = { + Sig1 = { DataSource = MyDS Type = uint32 } + } +} ++App = { + Class = RealTimeApplication + +States = { + Class = ReferenceContainer + +State1 = { + Class = RealTimeState + +Thread1 = { + Class = RealTimeThread + Functions = { GAM_Consumer, GAM_Producer } // Fail + } + } + +State2 = { + Class = RealTimeState + +Thread2 = { + Class = RealTimeThread + Functions = { GAM_Producer, GAM_Consumer } // Pass + } + } + } +} +` + pt := index.NewProjectTree() + p := parser.NewParser(content) + cfg, err := p.Parse() + if err != nil { + t.Fatal(err) + } + pt.AddFile("main.marte", cfg) + + // Use validator with default schema (embedded) + // We pass "." but it shouldn't matter if no .marte_schema.cue exists + v := validator.NewValidator(pt, ".") + v.ValidateProject() + + foundError := false + for _, d := range v.Diagnostics { + if strings.Contains(d.Message, "consumed by GAM '+GAM_Consumer'") && + strings.Contains(d.Message, "before being produced") { + foundError = true + } + } + + if !foundError { + t.Error("Expected INOUT ordering error for State1") + for _, d := range v.Diagnostics { + t.Logf("Diag: %s", d.Message) + } + } + + foundErrorState2 := false + for _, d := range v.Diagnostics { + if strings.Contains(d.Message, "State '+State2'") && strings.Contains(d.Message, "before being produced") { + foundErrorState2 = true + } + } + + if foundErrorState2 { + t.Error("Unexpected INOUT ordering error for State2 (Correct order)") + } +} diff --git a/test/variables_test.go b/test/variables_test.go new file mode 100644 index 0000000..2b0637c --- /dev/null +++ b/test/variables_test.go @@ -0,0 +1,72 @@ +package integration + +import ( + "os" + "strings" + "testing" + + "github.com/marte-community/marte-dev-tools/internal/builder" + "github.com/marte-community/marte-dev-tools/internal/parser" +) + +func TestVariables(t *testing.T) { + content := ` +#var MyInt: int = 10 +#var MyStr: string = "default" + ++Obj = { + Class = Test + Field1 = $MyInt + Field2 = $MyStr +} +` + // Test Parsing + p := parser.NewParser(content) + cfg, err := p.Parse() + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + + // Check definitions: #var, #var, +Obj + if len(cfg.Definitions) != 3 { + t.Errorf("Expected 3 definitions, got %d", len(cfg.Definitions)) + } + + // Test Builder resolution + f, _ := os.CreateTemp("", "vars.marte") + f.WriteString(content) + f.Close() + defer os.Remove(f.Name()) + + // Build with override + overrides := map[string]string{ + "MyInt": "999", + } + + b := builder.NewBuilder([]string{f.Name()}, overrides) + + outF, _ := os.CreateTemp("", "out.marte") + outName := outF.Name() + defer os.Remove(outName) + + err = b.Build(outF) + outF.Close() + + if err != nil { + t.Fatalf("Build failed: %v", err) + } + + outContent, _ := os.ReadFile(outName) + outStr := string(outContent) + + if !strings.Contains(outStr, "Field1 = 999") { + t.Errorf("Variable override failed for MyInt. Got:\n%s", outStr) + } + if !strings.Contains(outStr, "Field2 = \"default\"") { + t.Errorf("Default value failed for MyStr. Got:\n%s", outStr) + } + // Check #var is removed + if strings.Contains(outStr, "#var") { + t.Error("#var definition present in output") + } +}