From abd1174941b7e074d3cd14c3a121febd6eed96b7 Mon Sep 17 00:00:00 2001 From: Bernhard Loos Date: Fri, 26 Aug 2011 04:53:39 +0200 Subject: [PATCH] msi: Protected primary keys against modification. --- dlls/msi/msipriv.h | 1 + dlls/msi/record.c | 50 ++++++++++++++++++++++++----------------- dlls/msi/where.c | 56 ++++++++++++++++++++++++++++++++++++++++------ 3 files changed, 79 insertions(+), 28 deletions(-) diff --git a/dlls/msi/msipriv.h b/dlls/msi/msipriv.h index 2524752f0f9..998970baf30 100644 --- a/dlls/msi/msipriv.h +++ b/dlls/msi/msipriv.h @@ -809,6 +809,7 @@ extern UINT MSI_RecordSetStreamFromFileW( MSIRECORD *, UINT, LPCWSTR ) DECLSPEC_ extern UINT MSI_RecordCopyField( MSIRECORD *, UINT, MSIRECORD *, UINT ) DECLSPEC_HIDDEN; extern MSIRECORD *MSI_CloneRecord( MSIRECORD * ) DECLSPEC_HIDDEN; extern BOOL MSI_RecordsAreEqual( MSIRECORD *, MSIRECORD * ) DECLSPEC_HIDDEN; +extern BOOL MSI_RecordsAreFieldsEqual(MSIRECORD *a, MSIRECORD *b, UINT field) DECLSPEC_HIDDEN; /* stream internals */ extern void enum_stream_names( IStorage *stg ) DECLSPEC_HIDDEN; diff --git a/dlls/msi/record.c b/dlls/msi/record.c index 0e4fb8a2624..7acbfc77b9a 100644 --- a/dlls/msi/record.c +++ b/dlls/msi/record.c @@ -994,6 +994,34 @@ MSIRECORD *MSI_CloneRecord(MSIRECORD *rec) return clone; } +BOOL MSI_RecordsAreFieldsEqual(MSIRECORD *a, MSIRECORD *b, UINT field) +{ + if (a->fields[field].type != b->fields[field].type) + return FALSE; + + switch (a->fields[field].type) + { + case MSIFIELD_NULL: + break; + + case MSIFIELD_INT: + if (a->fields[field].u.iVal != b->fields[field].u.iVal) + return FALSE; + break; + + case MSIFIELD_WSTR: + if (strcmpW(a->fields[field].u.szwVal, b->fields[field].u.szwVal)) + return FALSE; + break; + + case MSIFIELD_STREAM: + default: + return FALSE; + } + return TRUE; +} + + BOOL MSI_RecordsAreEqual(MSIRECORD *a, MSIRECORD *b) { UINT i; @@ -1003,28 +1031,8 @@ BOOL MSI_RecordsAreEqual(MSIRECORD *a, MSIRECORD *b) for (i = 0; i <= a->count; i++) { - if (a->fields[i].type != b->fields[i].type) + if (!MSI_RecordsAreFieldsEqual( a, b, i )) return FALSE; - - switch (a->fields[i].type) - { - case MSIFIELD_NULL: - break; - - case MSIFIELD_INT: - if (a->fields[i].u.iVal != b->fields[i].u.iVal) - return FALSE; - break; - - case MSIFIELD_WSTR: - if (strcmpW(a->fields[i].u.szwVal, b->fields[i].u.szwVal)) - return FALSE; - break; - - case MSIFIELD_STREAM: - default: - return FALSE; - } } return TRUE; diff --git a/dlls/msi/where.c b/dlls/msi/where.c index d8ac5e8c20a..07422448d39 100644 --- a/dlls/msi/where.c +++ b/dlls/msi/where.c @@ -262,9 +262,10 @@ static UINT WHERE_get_row( struct tagMSIVIEW *view, UINT row, MSIRECORD **rec ) static UINT WHERE_set_row( struct tagMSIVIEW *view, UINT row, MSIRECORD *rec, UINT mask ) { MSIWHEREVIEW *wv = (MSIWHEREVIEW*)view; - UINT r, offset = 0, reduced_mask = 0; + UINT i, r, offset = 0; JOINTABLE *table = wv->tables; UINT *rows; + UINT mask_copy = mask; TRACE("%p %d %p %08x\n", wv, row, rec, mask ); @@ -275,28 +276,54 @@ static UINT WHERE_set_row( struct tagMSIVIEW *view, UINT row, MSIRECORD *rec, UI if (r != ERROR_SUCCESS) return r; - if(mask >= 1 << wv->col_count) + if (mask >= 1 << wv->col_count) return ERROR_INVALID_PARAMETER; + do + { + for (i = 0; i < table->col_count; i++) { + UINT type; + + if (!(mask_copy & (1 << i))) + continue; + r = table->view->ops->get_column_info(table->view, i + 1, NULL, + &type, NULL, NULL ); + if (r != ERROR_SUCCESS) + return r; + if (type & MSITYPE_KEY) + return ERROR_FUNCTION_FAILED; + } + mask_copy >>= table->col_count; + } + while (mask_copy && (table = table->next)); + + table = wv->tables; + do { const UINT col_count = table->col_count; UINT i; MSIRECORD *reduced; + UINT reduced_mask = (mask >> offset) & ((1 << col_count) - 1); + + if (!reduced_mask) + { + offset += col_count; + continue; + } reduced = MSI_CreateRecord(col_count); if (!reduced) return ERROR_FUNCTION_FAILED; - for (i = 0; i < col_count; i++) + for (i = 1; i <= col_count; i++) { - r = MSI_RecordCopyField(rec, i + offset + 1, reduced, i + 1); + r = MSI_RecordCopyField(rec, i + offset, reduced, i); if (r != ERROR_SUCCESS) break; } offset += col_count; - reduced_mask = mask >> (wv->col_count - offset) & ((1 << col_count) - 1); if (r == ERROR_SUCCESS) r = table->view->ops->set_row(table->view, rows[table->table_index], reduced, reduced_mask); @@ -644,13 +671,28 @@ static UINT join_find_row( MSIWHEREVIEW *wv, MSIRECORD *rec, UINT *row ) static UINT join_modify_update( struct tagMSIVIEW *view, MSIRECORD *rec ) { MSIWHEREVIEW *wv = (MSIWHEREVIEW *)view; - UINT r, row; + UINT r, row, i, mask = 0; + MSIRECORD *current; + r = join_find_row( wv, rec, &row ); if (r != ERROR_SUCCESS) return r; - return WHERE_set_row( view, row, rec, (1 << wv->col_count) - 1 ); + r = msi_view_get_row( wv->db, view, row, ¤t ); + if (r != ERROR_SUCCESS) + return r; + + assert(MSI_RecordGetFieldCount(rec) == MSI_RecordGetFieldCount(current)); + + for (i = MSI_RecordGetFieldCount(rec); i > 0; i--) + { + if (!MSI_RecordsAreFieldsEqual(rec, current, i)) + mask |= 1 << (i - 1); + } + msiobj_release(¤t->hdr); + + return WHERE_set_row( view, row, rec, mask ); } static UINT WHERE_modify( struct tagMSIVIEW *view, MSIMODIFY eModifyMode,