001    package org.apache.torque.task;
002    
003    /*
004     * Licensed to the Apache Software Foundation (ASF) under one
005     * or more contributor license agreements.  See the NOTICE file
006     * distributed with this work for additional information
007     * regarding copyright ownership.  The ASF licenses this file
008     * to you under the Apache License, Version 2.0 (the
009     * "License"); you may not use this file except in compliance
010     * with the License.  You may obtain a copy of the License at
011     *
012     *   http://www.apache.org/licenses/LICENSE-2.0
013     *
014     * Unless required by applicable law or agreed to in writing,
015     * software distributed under the License is distributed on an
016     * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
017     * KIND, either express or implied.  See the License for the
018     * specific language governing permissions and limitations
019     * under the License.
020     */
021    
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintStream;
import java.io.Reader;
import java.io.StringReader;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.Driver;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.SQLWarning;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Properties;

import org.apache.commons.lang.StringUtils;
import org.apache.tools.ant.AntClassLoader;
import org.apache.tools.ant.BuildException;
import org.apache.tools.ant.Project;
import org.apache.tools.ant.PropertyHelper;
import org.apache.tools.ant.Task;
import org.apache.tools.ant.types.EnumeratedAttribute;
import org.apache.tools.ant.types.Path;
import org.apache.tools.ant.types.Reference;
057    
/**
 * This task reads a properties file (the "sqldbmap") that maps each SQL file to a database, and inserts every SQL
 * file listed there into its designated database.
 * 
 * @author <a href="mailto:jeff@custommonkey.org">Jeff Martin</a>
 * @author <a href="mailto:gholam@xtra.co.nz">Michael McCallum</a>
 * @author <a href="mailto:tim.stephenson@sybase.com">Tim Stephenson</a>
 * @author <a href="mailto:jvanzyl@apache.org">Jason van Zyl</a>
 * @author <a href="mailto:mpoeschl@marmot.at">Martin Poeschl</a>
 * @version $Id: TorqueSQLExec.java,v 1.1 2007-10-21 07:57:26 abyrne Exp $
 */
069    public class TorqueSQLExec extends Task {
070            private int goodSql = 0;
071            private int totalSql = 0;
072            private Path classpath;
073            private AntClassLoader loader;
074    
075            /**
076         *
077         */
078            public static class DelimiterType extends EnumeratedAttribute {
079                    public static final String NORMAL = "normal";
080                    public static final String ROW = "row";
081    
082                    public String[] getValues() {
083                            return new String[] { NORMAL, ROW };
084                    }
085            }
086    
087            /** Database connection */
088            private Connection conn = null;
089    
090            /** Autocommit flag. Default value is false */
091            private boolean autocommit = false;
092    
093            /** SQL statement */
094            private Statement statement = null;
095    
096            /** DB driver. */
097            private String driver = null;
098    
099            /** DB url. */
100            private String url = null;
101    
102            /** User name. */
103            private String userId = null;
104    
105            /** Password */
106            private String password = null;
107    
108            /** SQL Statement delimiter */
109            private String delimiter = ";";
110    
111            /**
112             * The delimiter type indicating whether the delimiter will only be recognized on a line by itself
113             */
114            private String delimiterType = DelimiterType.NORMAL;
115    
116            /** Print SQL results. */
117            private boolean print = false;
118    
119            /** Print header columns. */
120            private boolean showheaders = true;
121    
122            /** Results Output file. */
123            private File output = null;
124    
125            /** RDBMS Product needed for this SQL. */
126            private String rdbms = null;
127    
128            /** RDBMS Version needed for this SQL. */
129            private String version = null;
130    
131            /** Action to perform if an error is found */
132            private String onError = "abort";
133    
134            /** Encoding to use when reading SQL statements from a file */
135            private String encoding = null;
136    
137            /** Src directory for the files listed in the sqldbmap. */
138            private String srcDir;
139    
140            /** Properties file that maps an individual SQL file to a database. */
141            private File sqldbmap;
142    
143            /**
144             * Set the sqldbmap properties file.
145             * 
146             * @param sqldbmap
147             *            filename for the sqldbmap
148             */
149            public void setSqlDbMap(String sqldbmap) {
150                    this.sqldbmap = getProject().resolveFile(sqldbmap);
151            }
152    
153            /**
154             * Get the sqldbmap properties file.
155             * 
156             * @return filename for the sqldbmap
157             */
158            public File getSqlDbMap() {
159                    return sqldbmap;
160            }
161    
162            /**
163             * Set the src directory for the sql files listed in the sqldbmap file.
164             * 
165             * @param srcDir
166             *            sql source directory
167             */
168            public void setSrcDir(String srcDir) {
169                    this.srcDir = getProject().resolveFile(srcDir).toString();
170            }
171    
172            /**
173             * Get the src directory for the sql files listed in the sqldbmap file.
174             * 
175             * @return sql source directory
176             */
177            public String getSrcDir() {
178                    return srcDir;
179            }
180    
181            /**
182             * Set the classpath for loading the driver.
183             * 
184             * @param classpath
185             *            the classpath
186             */
187            public void setClasspath(Path classpath) {
188                    if (this.classpath == null) {
189                            this.classpath = classpath;
190                    } else {
191                            this.classpath.append(classpath);
192                    }
193            }
194    
195            /**
196             * Create the classpath for loading the driver.
197             * 
198             * @return the classpath
199             */
200            public Path createClasspath() {
201                    if (this.classpath == null) {
202                            this.classpath = new Path(getProject());
203                    }
204                    return this.classpath.createPath();
205            }
206    
207            /**
208             * Set the classpath for loading the driver using the classpath reference.
209             * 
210             * @param r
211             *            reference to the classpath
212             */
213            public void setClasspathRef(Reference r) {
214                    createClasspath().setRefid(r);
215            }
216    
217            /**
218             * Set the sql command to execute
219             * 
220             * @param sql
221             *            sql command to execute
222             * @deprecated This method has no effect and will be removed in a future version.
223             */
224            public void addText(String sql) {
225            }
226    
227            /**
228             * Set the JDBC driver to be used.
229             * 
230             * @param driver
231             *            driver class name
232             */
233            public void setDriver(String driver) {
234                    this.driver = driver;
235            }
236    
237            /**
238             * Set the DB connection url.
239             * 
240             * @param url
241             *            connection url
242             */
243            public void setUrl(String url) {
244                    this.url = url;
245            }
246    
247            /**
248             * Set the user name for the DB connection.
249             * 
250             * @param userId
251             *            database user
252             */
253            public void setUserid(String userId) {
254                    this.userId = userId;
255            }
256    
257            /**
258             * Set the file encoding to use on the sql files read in
259             * 
260             * @param encoding
261             *            the encoding to use on the files
262             */
263            public void setEncoding(String encoding) {
264                    this.encoding = encoding;
265            }
266    
267            /**
268             * Set the password for the DB connection.
269             * 
270             * @param password
271             *            database password
272             */
273            public void setPassword(String password) {
274                    this.password = password;
275            }
276    
277            /**
278             * Set the autocommit flag for the DB connection.
279             * 
280             * @param autocommit
281             *            the autocommit flag
282             */
283            public void setAutocommit(boolean autocommit) {
284                    this.autocommit = autocommit;
285            }
286    
287            /**
288             * Set the statement delimiter.
289             * 
290             * <p>
291             * For example, set this to "go" and delimitertype to "ROW" for Sybase ASE or MS SQL Server.
292             * </p>
293             * 
294             * @param delimiter
295             */
296            public void setDelimiter(String delimiter) {
297                    this.delimiter = delimiter;
298            }
299    
300            /**
301             * Set the Delimiter type for this sql task. The delimiter type takes two values - normal and row. Normal means that
302             * any occurence of the delimiter terminate the SQL command whereas with row, only a line containing just the
303             * delimiter is recognized as the end of the command.
304             * 
305             * @param delimiterType
306             */
307            public void setDelimiterType(DelimiterType delimiterType) {
308                    this.delimiterType = delimiterType.getValue();
309            }
310    
311            /**
312             * Set the print flag.
313             * 
314             * @param print
315             */
316            public void setPrint(boolean print) {
317                    this.print = print;
318            }
319    
320            /**
321             * Set the showheaders flag.
322             * 
323             * @param showheaders
324             */
325            public void setShowheaders(boolean showheaders) {
326                    this.showheaders = showheaders;
327            }
328    
329            /**
330             * Set the output file.
331             * 
332             * @param output
333             */
334            public void setOutput(File output) {
335                    this.output = output;
336            }
337    
338            /**
339             * Set the rdbms required
340             * 
341             * @param vendor
342             */
343            public void setRdbms(String vendor) {
344                    this.rdbms = vendor.toLowerCase();
345            }
346    
347            /**
348             * Set the version required
349             * 
350             * @param version
351             */
352            public void setVersion(String version) {
353                    this.version = version.toLowerCase();
354            }
355    
356            /**
357             * Set the action to perform onerror
358             * 
359             * @param action
360             */
361            public void setOnerror(OnError action) {
362                    this.onError = action.getValue();
363            }
364    
365            /**
366             * Load the sql file and then execute it
367             * 
368             * @throws BuildException
369             */
370            @SuppressWarnings("unchecked")
371            public void execute() throws BuildException {
372                    if (sqldbmap == null || getSqlDbMap().exists() == false) {
373                            throw new BuildException("You haven't provided an sqldbmap, or " + "the one you specified doesn't exist: " + sqldbmap);
374                    }
375    
376                    if (driver == null) {
377                            throw new BuildException("Driver attribute must be set!", getLocation());
378                    }
379                    if (userId == null) {
380                            throw new BuildException("User Id attribute must be set!", getLocation());
381                    }
382                    if (password == null) {
383                            throw new BuildException("Password attribute must be set!", getLocation());
384                    }
385                    if (url == null) {
386                            throw new BuildException("Url attribute must be set!", getLocation());
387                    }
388    
389                    Properties map = new Properties();
390    
391                    try {
392                            FileInputStream fis = new FileInputStream(getSqlDbMap());
393                            map.load(fis);
394                            fis.close();
395                    } catch (IOException ioe) {
396                            throw new BuildException("Cannot open and process the sqldbmap!");
397                    }
398    
399                    Map<Object, Object> databases = new HashMap<Object, Object>();
400    
401                    Iterator<?> eachFileName = map.keySet().iterator();
402                    while (eachFileName.hasNext()) {
403                            String sqlfile = (String) eachFileName.next();
404                            String database = map.getProperty(sqlfile);
405    
406                            List<Object> files = (List<Object>) databases.get(database);
407    
408                            if (files == null) {
409                                    files = new ArrayList<Object>();
410                                    databases.put(database, files);
411                            }
412    
413                            // We want to make sure that the base schemas
414                            // are inserted first.
415                            if (sqlfile.indexOf("schema.sql") != -1) {
416                                    files.add(0, sqlfile);
417                            } else {
418                                    files.add(sqlfile);
419                            }
420                    }
421    
422                    Iterator<?> eachDatabase = databases.keySet().iterator();
423                    while (eachDatabase.hasNext()) {
424                            String db = (String) eachDatabase.next();
425                            List<Object> transactions = new ArrayList<Object>();
426                            eachFileName = ((List<?>) databases.get(db)).iterator();
427                            while (eachFileName.hasNext()) {
428                                    String fileName = (String) eachFileName.next();
429                                    File file = new File(srcDir, fileName);
430    
431                                    if (file.exists()) {
432                                            Transaction transaction = new Transaction();
433                                            transaction.setSrc(file);
434                                            transactions.add(transaction);
435                                    } else {
436                                            System.out.println("File '" + file.getAbsolutePath() + "' in sqldbmap does not exist, so skipping it.");
437                                    }
438                            }
439    
440                            insertDatabaseSqlFiles(url, db, transactions);
441                    }
442            }
443    
444            /**
445             * Take the base url, the target database and insert a set of SQL files into the target database.
446             * 
447             * @param url
448             * @param database
449             * @param transactions
450             */
451            private void insertDatabaseSqlFiles(String url, String database, List<?> transactions) {
452                    url = StringUtils.replace(url, "@DB@", database);
453                    System.out.println("Our new url -> " + url);
454    
455                    Driver driverInstance = null;
456                    try {
457                            Class<?> dc;
458                            if (classpath != null) {
459                                    log("Loading " + driver + " using AntClassLoader with classpath " + classpath, Project.MSG_VERBOSE);
460    
461                                    loader = new AntClassLoader(getProject(), classpath);
462                                    dc = loader.loadClass(driver);
463                            } else {
464                                    log("Loading " + driver + " using system loader.", Project.MSG_VERBOSE);
465                                    dc = Class.forName(driver);
466                            }
467                            driverInstance = (Driver) dc.newInstance();
468                    } catch (ClassNotFoundException e) {
469                            throw new BuildException("Class Not Found: JDBC driver " + driver + " could not be loaded", getLocation());
470                    } catch (IllegalAccessException e) {
471                            throw new BuildException("Illegal Access: JDBC driver " + driver + " could not be loaded", getLocation());
472                    } catch (InstantiationException e) {
473                            throw new BuildException("Instantiation Exception: JDBC driver " + driver + " could not be loaded", getLocation());
474                    }
475    
476                    try {
477                            log("connecting to " + url, Project.MSG_VERBOSE);
478                            Properties info = new Properties();
479                            info.put("user", userId);
480                            info.put("password", password);
481                            conn = driverInstance.connect(url, info);
482    
483                            if (conn == null) {
484                                    // Driver doesn't understand the URL
485                                    throw new SQLException("No suitable Driver for " + url);
486                            }
487    
488                            if (!isValidRdbms(conn)) {
489                                    return;
490                            }
491    
492                            conn.setAutoCommit(autocommit);
493                            statement = conn.createStatement();
494                            PrintStream out = System.out;
495                            try {
496                                    if (output != null) {
497                                            log("Opening PrintStream to output file " + output, Project.MSG_VERBOSE);
498                                            out = new PrintStream(new BufferedOutputStream(new FileOutputStream(output)));
499                                    }
500    
501                                    // Process all transactions
502                                    for (Iterator<?> it = transactions.iterator(); it.hasNext();) {
503                                            Transaction transaction = (Transaction) it.next();
504                                            transaction.runTransaction(out);
505                                            if (!autocommit) {
506                                                    log("Commiting transaction", Project.MSG_VERBOSE);
507                                                    conn.commit();
508                                            }
509                                    }
510                            } finally {
511                                    if (out != null && out != System.out) {
512                                            out.close();
513                                    }
514                            }
515                    } catch (IOException e) {
516                            if (!autocommit && conn != null && onError.equals("abort")) {
517                                    try {
518                                            conn.rollback();
519                                    } catch (SQLException ex) {
520                                            // do nothing.
521                                    }
522                            }
523                            throw new BuildException(e, getLocation());
524                    } catch (SQLException e) {
525                            if (!autocommit && conn != null && onError.equals("abort")) {
526                                    try {
527                                            conn.rollback();
528                                    } catch (SQLException ex) {
529                                            // do nothing.
530                                    }
531                            }
532                            throw new BuildException(e, getLocation());
533                    } finally {
534                            try {
535                                    if (statement != null) {
536                                            statement.close();
537                                    }
538                                    if (conn != null) {
539                                            conn.close();
540                                    }
541                            } catch (SQLException e) {
542                            }
543                    }
544    
545                    System.out.println(goodSql + " of " + totalSql + " SQL statements executed successfully");
546            }
547    
548            /**
549             * Read the statements from the .sql file and execute them. Lines starting with '//', '--' or 'REM ' are ignored.
550             * 
551             * @param reader
552             * @param out
553             * @throws SQLException
554             * @throws IOException
555             */
556            protected void runStatements(Reader reader, PrintStream out) throws SQLException, IOException {
557                    String sql = "";
558                    String line = "";
559    
560                    BufferedReader in = new BufferedReader(reader);
561                    PropertyHelper ph = PropertyHelper.getPropertyHelper(getProject());
562    
563                    try {
564                            while ((line = in.readLine()) != null) {
565                                    line = line.trim();
566                                    line = ph.replaceProperties("", line, getProject().getProperties());
567                                    if (line.startsWith("//") || line.startsWith("--")) {
568                                            continue;
569                                    }
570                                    if (line.length() > 4 && line.substring(0, 4).equalsIgnoreCase("REM ")) {
571                                            continue;
572                                    }
573    
574                                    sql += " " + line;
575                                    sql = sql.trim();
576    
577                                    // SQL defines "--" as a comment to EOL
578                                    // and in Oracle it may contain a hint
579                                    // so we cannot just remove it, instead we must end it
580                                    if (line.indexOf("--") >= 0) {
581                                            sql += "\n";
582                                    }
583    
584                                    if (delimiterType.equals(DelimiterType.NORMAL) && sql.endsWith(delimiter) || delimiterType.equals(DelimiterType.ROW) && line.equals(delimiter)) {
585                                            log("SQL: " + sql, Project.MSG_VERBOSE);
586                                            execSQL(sql.substring(0, sql.length() - delimiter.length()), out);
587                                            sql = "";
588                                    }
589                            }
590    
591                            // Catch any statements not followed by ;
592                            if (!sql.equals("")) {
593                                    execSQL(sql, out);
594                            }
595                    } catch (SQLException e) {
596                            throw e;
597                    }
598            }
599    
600            /**
601             * Verify if connected to the correct RDBMS
602             * 
603             * @param conn
604             */
605            protected boolean isValidRdbms(Connection conn) {
606                    if (rdbms == null && version == null) {
607                            return true;
608                    }
609    
610                    try {
611                            DatabaseMetaData dmd = conn.getMetaData();
612    
613                            if (rdbms != null) {
614                                    String theVendor = dmd.getDatabaseProductName().toLowerCase();
615    
616                                    log("RDBMS = " + theVendor, Project.MSG_VERBOSE);
617                                    if (theVendor == null || theVendor.indexOf(rdbms) < 0) {
618                                            log("Not the required RDBMS: " + rdbms, Project.MSG_VERBOSE);
619                                            return false;
620                                    }
621                            }
622    
623                            if (version != null) {
624                                    String theVersion = dmd.getDatabaseProductVersion().toLowerCase();
625    
626                                    log("Version = " + theVersion, Project.MSG_VERBOSE);
627                                    if (theVersion == null || !(theVersion.startsWith(version) || theVersion.indexOf(" " + version) >= 0)) {
628                                            log("Not the required version: \"" + version + "\"", Project.MSG_VERBOSE);
629                                            return false;
630                                    }
631                            }
632                    } catch (SQLException e) {
633                            // Could not get the required information
634                            log("Failed to obtain required RDBMS information", Project.MSG_ERR);
635                            return false;
636                    }
637    
638                    return true;
639            }
640    
641            /**
642             * Exec the sql statement.
643             * 
644             * @param sql
645             * @param out
646             * @throws SQLException
647             */
648            protected void execSQL(String sql, PrintStream out) throws SQLException {
649                    // Check and ignore empty statements
650                    if ("".equals(sql.trim())) {
651                            return;
652                    }
653    
654                    try {
655                            totalSql++;
656                            if (!statement.execute(sql)) {
657                                    log(statement.getUpdateCount() + " rows affected", Project.MSG_VERBOSE);
658                            } else {
659                                    if (print) {
660                                            printResults(out);
661                                    }
662                            }
663    
664                            SQLWarning warning = conn.getWarnings();
665                            while (warning != null) {
666                                    log(warning + " sql warning", Project.MSG_VERBOSE);
667                                    warning = warning.getNextWarning();
668                            }
669                            conn.clearWarnings();
670                            goodSql++;
671                    } catch (SQLException e) {
672                            System.out.println("Failed to execute: " + sql);
673                            if (!onError.equals("continue")) {
674                                    throw e;
675                            }
676                            log(e.toString(), Project.MSG_ERR);
677                    }
678            }
679    
680            /**
681             * print any results in the statement.
682             * 
683             * @param out
684             * @throws SQLException
685             */
686            protected void printResults(PrintStream out) throws java.sql.SQLException {
687                    ResultSet rs = null;
688                    do {
689                            rs = statement.getResultSet();
690                            if (rs != null) {
691                                    log("Processing new result set.", Project.MSG_VERBOSE);
692                                    ResultSetMetaData md = rs.getMetaData();
693                                    int columnCount = md.getColumnCount();
694                                    StringBuffer line = new StringBuffer();
695                                    if (showheaders) {
696                                            for (int col = 1; col < columnCount; col++) {
697                                                    line.append(md.getColumnName(col));
698                                                    line.append(",");
699                                            }
700                                            line.append(md.getColumnName(columnCount));
701                                            out.println(line);
702                                            line.setLength(0);
703                                    }
704                                    while (rs.next()) {
705                                            boolean first = true;
706                                            for (int col = 1; col <= columnCount; col++) {
707                                                    String columnValue = rs.getString(col);
708                                                    if (columnValue != null) {
709                                                            columnValue = columnValue.trim();
710                                                    }
711    
712                                                    if (first) {
713                                                            first = false;
714                                                    } else {
715                                                            line.append(",");
716                                                    }
717                                                    line.append(columnValue);
718                                            }
719                                            out.println(line);
720                                            line.setLength(0);
721                                    }
722                            }
723                    } while (statement.getMoreResults());
724                    out.println();
725            }
726    
727            /**
728             * Enumerated attribute with the values "continue", "stop" and "abort" for the onerror attribute.
729             */
730            public static class OnError extends EnumeratedAttribute {
731                    public static final String CONTINUE = "continue";
732    
733                    public static final String STOP = "stop";
734    
735                    public static final String ABORT = "abort";
736    
737                    public String[] getValues() {
738                            return new String[] { CONTINUE, STOP, ABORT };
739                    }
740            }
741    
742            /**
743             * Contains the definition of a new transaction element. Transactions allow several files or blocks of statements to
744             * be executed using the same JDBC connection and commit operation in between.
745             */
746            public class Transaction {
747                    private File tSrcFile = null;
748                    private String tSqlCommand = "";
749    
750                    public void setSrc(File src) {
751                            this.tSrcFile = src;
752                    }
753    
754                    public void addText(String sql) {
755                            this.tSqlCommand += sql;
756                    }
757    
758                    private void runTransaction(PrintStream out) throws IOException, SQLException {
759                            if (tSqlCommand.length() != 0) {
760                                    log("Executing commands", Project.MSG_INFO);
761                                    runStatements(new StringReader(tSqlCommand), out);
762                            }
763    
764                            if (tSrcFile != null) {
765                                    System.out.println("Executing file: " + tSrcFile.getAbsolutePath());
766                                    Reader reader = (encoding == null) ? new FileReader(tSrcFile) : new InputStreamReader(new FileInputStream(tSrcFile), encoding);
767                                    runStatements(reader, out);
768                                    reader.close();
769                            }
770                    }
771            }
772    }