From 9d66d94780a2f1c42bae5c91530e778a9a186723 Mon Sep 17 00:00:00 2001 From: Mike McCormack Date: Sat, 26 Jun 2004 00:11:08 +0000 Subject: [PATCH] Implement queries by string value. --- dlls/msi/msipriv.h | 3 ++ dlls/msi/query.h | 8 ++++ dlls/msi/sql.y | 53 ++++++++++++---------- dlls/msi/string.c | 23 ++++++++-- dlls/msi/where.c | 108 +++++++++++++++++++++++++++++++++++++++------ 5 files changed, 156 insertions(+), 39 deletions(-) diff --git a/dlls/msi/msipriv.h b/dlls/msi/msipriv.h index 138d37479ef..25975ff79bb 100644 --- a/dlls/msi/msipriv.h +++ b/dlls/msi/msipriv.h @@ -191,6 +191,9 @@ extern VOID msi_destroy_stringtable( string_table *st ); extern UINT msi_string_count( string_table *st ); extern UINT msi_id_refcount( string_table *st, UINT i ); extern UINT msi_string_totalsize( string_table *st ); +extern UINT msi_strcmp( string_table *st, UINT lval, UINT rval, UINT *res ); +extern const char *msi_string_lookup_id( string_table *st, UINT id ); + UINT VIEW_find_column( MSIVIEW *view, LPWSTR name, UINT *n ); diff --git a/dlls/msi/query.h b/dlls/msi/query.h index 4c5fa3aa826..43dac657f9b 100644 --- a/dlls/msi/query.h +++ b/dlls/msi/query.h @@ -48,6 +48,13 @@ #define EXPR_IVAL 4 #define EXPR_SVAL 5 #define EXPR_UVAL 6 +#define EXPR_STRCMP 7 +#define EXPR_UTF8 8 + +struct sql_str { + LPCWSTR data; + INT len; +}; typedef struct _string_list { @@ -73,6 +80,7 @@ struct expr LPWSTR sval; LPWSTR column; UINT col_number; + char *utf8; } u; }; diff --git a/dlls/msi/sql.y b/dlls/msi/sql.y index 934b6479d89..8765409d167 100644 --- a/dlls/msi/sql.y +++ b/dlls/msi/sql.y @@ -48,7 +48,7 @@ typedef struct tag_SQL_input MSIVIEW **view; /* view structure for the resulting query */ } SQL_input; -static LPWSTR SQL_getstring( SQL_input *info ); +static LPWSTR SQL_getstring( struct sql_str *str ); static INT SQL_getint( SQL_input *sql ); static int SQL_lex( void *SQL_lval, SQL_input *info); @@ -61,9 +61,9 @@ static BOOL SQL_MarkPrimaryKeys( create_col_info *cols, string_list *keys); static struct expr * EXPR_complex( struct expr *l, UINT op, struct expr *r ); -static struct expr * EXPR_column( LPWSTR column ); -static struct expr * EXPR_ival( INT ival ); -static struct expr * EXPR_sval( LPWSTR string ); +static struct expr * EXPR_column( LPWSTR ); +static struct expr * EXPR_ival( struct sql_str *); +static struct expr * EXPR_sval( struct sql_str *); %} @@ -71,6 +71,7 @@ static struct expr * EXPR_sval( LPWSTR string ); %union { + struct sql_str str; LPWSTR string; string_list *column_list; value_list *val_list; @@ -92,8 +93,10 @@ static struct expr * EXPR_sval( LPWSTR string ); %token TK_GE TK_GLOB TK_GROUP TK_GT %token TK_HAVING TK_HOLD %token TK_IGNORE TK_ILLEGAL TK_IMMEDIATE TK_IN TK_INDEX TK_INITIALLY -%token TK_ID -%token TK_INSERT TK_INSTEAD TK_INT TK_INTEGER TK_INTERSECT TK_INTO TK_IS +%token TK_ID +%token TK_INSERT TK_INSTEAD TK_INT +%token TK_INTEGER +%token TK_INTERSECT TK_INTO TK_IS %token TK_ISNULL %token TK_JOIN TK_JOIN_KW %token TK_KEY @@ -106,7 +109,7 @@ static struct expr * EXPR_sval( LPWSTR string ); %token TK_RAISE TK_REFERENCES TK_REM TK_REPLACE TK_RESTRICT TK_ROLLBACK %token TK_ROW TK_RP TK_RSHIFT %token TK_SELECT TK_SEMI TK_SET TK_SHORT TK_SLASH TK_SPACE TK_STAR TK_STATEMENT -%token TK_STRING +%token TK_STRING %token TK_TABLE TK_TEMP TK_THEN TK_TRANSACTION TK_TRIGGER %token TK_UMINUS TK_UNCLOSED_STRING TK_UNION TK_UNIQUE %token TK_UPDATE TK_UPLUS TK_USING @@ -490,12 +493,11 @@ constlist: const_val: TK_INTEGER { - SQL_input* sql = (SQL_input*) info; - $$ = EXPR_ival( SQL_getint(sql) ); + $$ = EXPR_ival( &$1 ); } | TK_STRING { - $$ = EXPR_sval( $1 ); + $$ = EXPR_sval( &$1 ); } ; @@ -527,13 +529,11 @@ table: string_or_id: TK_ID { - SQL_input* sql = (SQL_input*) info; - $$ = SQL_getstring(sql); + $$ = SQL_getstring( &$1 ); } | TK_STRING { - SQL_input* sql = (SQL_input*) info; - $$ = SQL_getstring(sql); + $$ = SQL_getstring( &$1 ); } ; @@ -542,6 +542,7 @@ string_or_id: int SQL_lex( void *SQL_lval, SQL_input *sql) { int token; + struct sql_str * str = SQL_lval; do { @@ -553,6 +554,8 @@ int SQL_lex( void *SQL_lval, SQL_input *sql) sql->len = sqliteGetToken( &sql->command[sql->n], &token ); if( sql->len==0 ) break; + str->data = &sql->command[sql->n]; + str->len = sql->len; } while( token == TK_SPACE ); @@ -561,11 +564,11 @@ int SQL_lex( void *SQL_lval, SQL_input *sql) return token; } -LPWSTR SQL_getstring( SQL_input *sql ) +LPWSTR SQL_getstring( struct sql_str *strdata) { - LPCWSTR p = &sql->command[sql->n]; + LPCWSTR p = strdata->data; + UINT len = strdata->len; LPWSTR str; - UINT len = sql->len; /* if there's quotes, remove them */ if( (p[0]=='`') && (p[len-1]=='`') ) @@ -650,35 +653,35 @@ static struct expr * EXPR_complex( struct expr *l, UINT op, struct expr *r ) return e; } -static struct expr * EXPR_column( LPWSTR column ) +static struct expr * EXPR_column( LPWSTR str ) { struct expr *e = HeapAlloc( GetProcessHeap(), 0, sizeof *e ); if( e ) { e->type = EXPR_COLUMN; - e->u.column = column; + e->u.sval = str; } return e; } -static struct expr * EXPR_ival( INT ival ) +static struct expr * EXPR_ival( struct sql_str *str ) { struct expr *e = HeapAlloc( GetProcessHeap(), 0, sizeof *e ); if( e ) { e->type = EXPR_IVAL; - e->u.ival = ival; + e->u.ival = atoiW( str->data ); } return e; } -static struct expr * EXPR_sval( LPWSTR string ) +static struct expr * EXPR_sval( struct sql_str *str ) { struct expr *e = HeapAlloc( GetProcessHeap(), 0, sizeof *e ); if( e ) { e->type = EXPR_SVAL; - e->u.sval = string; + e->u.sval = SQL_getstring( str ); } return e; } @@ -692,6 +695,10 @@ void delete_expr( struct expr *e ) delete_expr( e->u.expr.left ); delete_expr( e->u.expr.right ); } + else if( e->type == EXPR_UTF8 ) + HeapFree( GetProcessHeap(), 0, e->u.utf8 ); + else if( e->type == EXPR_SVAL ) + HeapFree( GetProcessHeap(), 0, e->u.sval ); HeapFree( GetProcessHeap(), 0, e ); } diff --git a/dlls/msi/string.c b/dlls/msi/string.c index afed8234c0f..0e7cf77848c 100644 --- a/dlls/msi/string.c +++ b/dlls/msi/string.c @@ -208,7 +208,7 @@ int msi_addstringW( string_table *st, UINT n, const WCHAR *data, UINT len, UINT } /* find the string identified by an id - return null if there's none */ -static const char *string_lookup_id( string_table *st, UINT id ) +const char *msi_string_lookup_id( string_table *st, UINT id ) { if( id == 0 ) return ""; @@ -241,7 +241,7 @@ UINT msi_id2stringW( string_table *st, UINT id, LPWSTR buffer, UINT *sz ) TRACE("Finding string %d of %d\n", id, st->count); - str = string_lookup_id( st, id ); + str = msi_string_lookup_id( st, id ); if( !str ) return ERROR_FUNCTION_FAILED; @@ -277,7 +277,7 @@ UINT msi_id2stringA( string_table *st, UINT id, LPSTR buffer, UINT *sz ) TRACE("Finding string %d of %d\n", id, st->count); - str = string_lookup_id( st, id ); + str = msi_string_lookup_id( st, id ); if( !str ) return ERROR_FUNCTION_FAILED; @@ -353,6 +353,23 @@ UINT msi_string2id( string_table *st, LPCWSTR buffer, UINT *id ) return r; } +UINT msi_strcmp( string_table *st, UINT lval, UINT rval, UINT *res ) +{ + const char *l_str, *r_str; /* utf8 */ + + l_str = msi_string_lookup_id( st, lval ); + if( !l_str ) + return ERROR_INVALID_PARAMETER; + + r_str = msi_string_lookup_id( st, rval ); + if( !r_str ) + return ERROR_INVALID_PARAMETER; + + /* does this do the right thing for all UTF-8 strings? */ + *res = strcmp( l_str, r_str ); + + return ERROR_SUCCESS; +} UINT msi_string_count( string_table *st ) { diff --git a/dlls/msi/where.c b/dlls/msi/where.c index 981d2ee8d3a..c6337243561 100644 --- a/dlls/msi/where.c +++ b/dlls/msi/where.c @@ -95,7 +95,54 @@ static UINT INT_evaluate( UINT lval, UINT op, UINT rval ) return 0; } -static UINT WHERE_evaluate( MSIVIEW *table, UINT row, +static const char *STRING_evaluate( string_table *st, + MSIVIEW *table, UINT row, struct expr *expr ) +{ + UINT val = 0, r; + + switch( expr->type ) + { + case EXPR_COL_NUMBER: + r = table->ops->fetch_int( table, row, expr->u.col_number, &val ); + if( r != ERROR_SUCCESS ) + return NULL; + return msi_string_lookup_id( st, val ); + + case EXPR_UTF8: + return expr->u.utf8; + + default: + ERR("Invalid expression type\n"); + break; + } + return NULL; +} + +static UINT STRCMP_Evaluate( string_table *st, MSIVIEW *table, UINT row, + struct expr *cond, UINT *val ) +{ + int sr; + const char *l_str, *r_str; + + l_str = STRING_evaluate( st, table, row, cond->u.expr.left ); + r_str = STRING_evaluate( st, table, row, cond->u.expr.right ); + if( l_str == r_str ) + sr = 0; + else if( l_str && ! r_str ) + sr = 1; + else if( r_str && ! l_str ) + sr = -1; + else + sr = strcmp( l_str, r_str ); + + *val = ( cond->u.expr.op == OP_EQ && ( sr == 0 ) ) || + ( cond->u.expr.op == OP_LT && ( sr < 0 ) ) || + ( cond->u.expr.op == OP_GT && ( sr > 0 ) ); + + return ERROR_SUCCESS; +} + +static UINT WHERE_evaluate( MSIDATABASE *db, MSIVIEW *table, UINT row, struct expr *cond, UINT *val ) { UINT r, lval, rval; @@ -117,15 +164,18 @@ static UINT WHERE_evaluate( MSIVIEW *table, UINT row, return ERROR_SUCCESS; case EXPR_COMPLEX: - r = WHERE_evaluate( table, row, cond->u.expr.left, &lval ); + r = WHERE_evaluate( db, table, row, cond->u.expr.left, &lval ); if( r != ERROR_SUCCESS ) return r; - r = WHERE_evaluate( table, row, cond->u.expr.right, &rval ); + r = WHERE_evaluate( db, table, row, cond->u.expr.right, &rval ); if( r != ERROR_SUCCESS ) return r; *val = INT_evaluate( lval, cond->u.expr.op, rval ); return ERROR_SUCCESS; + case EXPR_STRCMP: + return STRCMP_Evaluate( db->strings, table, row, cond, val ); + default: ERR("Invalid expression type\n"); break; @@ -161,7 +211,7 @@ static UINT WHERE_execute( struct tagMSIVIEW *view, MSIHANDLE record ) for( i=0; icond, &val ); + r = WHERE_evaluate( wv->db, table, i, wv->cond, &val ); if( r != ERROR_SUCCESS ) return r; if( val ) @@ -295,20 +345,21 @@ UINT WHERE_CreateView( MSIDATABASE *db, MSIVIEW **view, MSIVIEW *table ) return ERROR_SUCCESS; } -static UINT WHERE_VerifyCondition( MSIVIEW *table, struct expr *cond, +static UINT WHERE_VerifyCondition( MSIDATABASE *db, MSIVIEW *table, struct expr *cond, UINT *valid ) { - UINT r, col = 0; + UINT r, val = 0, len; + char *str; switch( cond->type ) { case EXPR_COLUMN: - r = VIEW_find_column( table, cond->u.column, &col ); + r = VIEW_find_column( table, cond->u.column, &val ); if( r == ERROR_SUCCESS ) { *valid = 1; cond->type = EXPR_COL_NUMBER; - cond->u.col_number = col; + cond->u.col_number = val; } else { @@ -317,14 +368,35 @@ static UINT WHERE_VerifyCondition( MSIVIEW *table, struct expr *cond, } break; case EXPR_COMPLEX: - r = WHERE_VerifyCondition( table, cond->u.expr.left, valid ); + r = WHERE_VerifyCondition( db, table, cond->u.expr.left, valid ); if( r != ERROR_SUCCESS ) return r; if( !*valid ) return ERROR_SUCCESS; - r = WHERE_VerifyCondition( table, cond->u.expr.right, valid ); + r = WHERE_VerifyCondition( db, table, cond->u.expr.right, valid ); if( r != ERROR_SUCCESS ) return r; + + /* check the type of the comparison */ + if( ( cond->u.expr.left->type == EXPR_UTF8 ) || + ( cond->u.expr.right->type == EXPR_UTF8 ) ) + { + switch( cond->u.expr.op ) + { + case OP_EQ: + case OP_GT: + case OP_LT: + break; + default: + *valid = FALSE; + return ERROR_INVALID_PARAMETER; + } + + /* FIXME: check we're comparing a string to a column */ + + cond->type = EXPR_STRCMP; + } + break; case EXPR_IVAL: *valid = 1; @@ -332,8 +404,18 @@ static UINT WHERE_VerifyCondition( MSIVIEW *table, struct expr *cond, cond->u.uval = cond->u.ival + (1<<15); break; case EXPR_SVAL: - *valid = 0; - FIXME("can't deal with string values yet\n"); + /* convert to UTF8 so we have the same format as the DB */ + len = WideCharToMultiByte( CP_UTF8, 0, + cond->u.sval, -1, NULL, 0, NULL, NULL); + str = HeapAlloc( GetProcessHeap(), 0, len ); + if( !str ) + return ERROR_OUTOFMEMORY; + WideCharToMultiByte( CP_UTF8, 0, + cond->u.sval, -1, str, len, NULL, NULL); + HeapFree( GetProcessHeap(), 0, cond->u.sval ); + cond->type = EXPR_UTF8; + cond->u.utf8 = str; + *valid = 1; break; default: ERR("Invalid expression type\n"); @@ -359,7 +441,7 @@ UINT WHERE_AddCondition( MSIVIEW *view, struct expr *cond ) TRACE("Adding condition\n"); - r = WHERE_VerifyCondition( wv->table, cond, &valid ); + r = WHERE_VerifyCondition( wv->db, wv->table, cond, &valid ); if( r != ERROR_SUCCESS ) ERR("condition evaluation failed\n");