Skip to content

Commit 9bd69b6

Browse files
committed
session_rpcserver: pass through appropriate priv map
1 parent 25aae89 commit 9bd69b6

File tree

1 file changed

+51
-48
lines changed

1 file changed

+51
-48
lines changed

session_rpcserver.go

Lines changed: 51 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,55 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
841841
privacy := !req.NoPrivacyMapper
842842
privacyMapPairs := make(map[string]string)
843843

844+
// If a previous session ID has been set to link this new one to, we
845+
// first check if we have the referenced session, and we make sure it
846+
// has been revoked.
847+
var (
848+
linkedGroupID *session.ID
849+
linkedGroupSession *session.Session
850+
privDB firewalldb.PrivacyMapDB
851+
)
852+
if len(req.LinkedGroupId) != 0 {
853+
var groupID session.ID
854+
copy(groupID[:], req.LinkedGroupId)
855+
856+
// Check that the group actually does exist.
857+
groupSess, err := s.cfg.db.GetSessionByID(groupID)
858+
if err != nil {
859+
return nil, err
860+
}
861+
862+
// Ensure that the linked session is in fact the first session
863+
// in its group.
864+
if groupSess.ID != groupSess.GroupID {
865+
return nil, fmt.Errorf("can not link to session "+
866+
"%x since it is not the first in the session "+
867+
"group %x", groupSess.ID, groupSess.GroupID)
868+
}
869+
870+
// Now we need to check that all the sessions in the group are
871+
// no longer active.
872+
ok, err := s.cfg.db.CheckSessionGroupPredicate(
873+
groupID, func(s *session.Session) bool {
874+
return s.State == session.StateRevoked ||
875+
s.State == session.StateExpired
876+
},
877+
)
878+
if err != nil {
879+
return nil, err
880+
}
881+
882+
if !ok {
883+
return nil, fmt.Errorf("a linked session in group "+
884+
"%x is still active", groupID)
885+
}
886+
887+
linkedGroupID = &groupID
888+
linkedGroupSession = groupSess
889+
890+
privDB = s.cfg.privMap(groupID)
891+
}
892+
844893
// First need to fetch all the perms that need to be baked into this
845894
// mac based on the features.
846895
allFeatures, err := s.cfg.autopilot.ListFeatures(ctx)
@@ -883,7 +932,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
883932
if privacy {
884933
var privMapPairs map[string]string
885934
v, privMapPairs, err = v.RealToPseudo(
886-
nil,
935+
privDB,
887936
)
888937
if err != nil {
889938
return nil, err
@@ -1014,52 +1063,6 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
10141063
caveats = append(caveats, firewall.MetaPrivacyCaveat)
10151064
}
10161065

1017-
// If a previous session ID has been set to link this new one to, we
1018-
// first check if we have the referenced session, and we make sure it
1019-
// has been revoked.
1020-
var (
1021-
linkedGroupID *session.ID
1022-
linkedGroupSession *session.Session
1023-
)
1024-
if len(req.LinkedGroupId) != 0 {
1025-
var groupID session.ID
1026-
copy(groupID[:], req.LinkedGroupId)
1027-
1028-
// Check that the group actually does exist.
1029-
groupSess, err := s.cfg.db.GetSessionByID(groupID)
1030-
if err != nil {
1031-
return nil, err
1032-
}
1033-
1034-
// Ensure that the linked session is in fact the first session
1035-
// in its group.
1036-
if groupSess.ID != groupSess.GroupID {
1037-
return nil, fmt.Errorf("can not link to session "+
1038-
"%x since it is not the first in the session "+
1039-
"group %x", groupSess.ID, groupSess.GroupID)
1040-
}
1041-
1042-
// Now we need to check that all the sessions in the group are
1043-
// no longer active.
1044-
ok, err := s.cfg.db.CheckSessionGroupPredicate(
1045-
groupID, func(s *session.Session) bool {
1046-
return s.State == session.StateRevoked ||
1047-
s.State == session.StateExpired
1048-
},
1049-
)
1050-
if err != nil {
1051-
return nil, err
1052-
}
1053-
1054-
if !ok {
1055-
return nil, fmt.Errorf("a linked session in group "+
1056-
"%x is still active", groupID)
1057-
}
1058-
1059-
linkedGroupID = &groupID
1060-
linkedGroupSession = groupSess
1061-
}
1062-
10631066
s.sessRegMu.Lock()
10641067
defer s.sessRegMu.Unlock()
10651068

@@ -1096,7 +1099,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
10961099
}
10971100

10981101
// Register all the privacy map pairs for this session ID.
1099-
privDB := s.cfg.privMap(sess.GroupID)
1102+
privDB = s.cfg.privMap(sess.GroupID)
11001103
err = privDB.Update(func(tx firewalldb.PrivacyMapTx) error {
11011104
for r, p := range privacyMapPairs {
11021105
err := tx.NewPair(r, p)

0 commit comments

Comments
 (0)