【AI】Postgres + Drizzle + Embeddingで意味検索する
こんにちは、フリーランスエンジニアの太田雅昭です。
構成
今回、下記を使用します。
- Postgress
- Drizzle
- Open AI API
初めはPrismaで頑張っていたのですが、自由度が低く厳しそうでしたので、Drizzleに乗り換えた次第です。Drizzleならインデックス含めサクッとできました。感激。
Embedding
Embeddingは、最近流行りのAI技術です。文章をベクトル化し、意味合いのマッチ具合を判定できるようになります。
実装とテスト
準備
環境変数を設定します。
DATABASE_URL="postgresql://postgres:@localhost:5432/mydb"
OPENAI_API_KEY="sk-proj-xxx"
定数を定義します。
export const EMBEDDING_DIMENSIONS = 1536;
export const EMBEDDING_MODEL = "text-embedding-3-small";
スキーマを作ります。HNSWインデックス、cosineを使用しています。
import { index, integer, pgTable, varchar, vector } from "drizzle-orm/pg-core";
import { EMBEDDING_DIMENSIONS } from '../constants';
export const postsTable = pgTable(
"posts",
{
id: integer().primaryKey().generatedAlwaysAsIdentity(),
content: varchar({ length: 255 }).notNull(),
embedding: vector({ dimensions: EMBEDDING_DIMENSIONS }).notNull(),
},
(table) => [
index('embedding_hnsw_index').using('hnsw', table.embedding.op('vector_cosine_ops')),
]
);
dbをどこでも使用できるようにします。
import { drizzle } from 'drizzle-orm/node-postgres';
export const db = drizzle(process.env.DATABASE_URL!);
APIをラップします。
import OpenAI from "openai";
import { EMBEDDING_DIMENSIONS, EMBEDDING_MODEL } from "./constants";
const client = new OpenAI();
export async function embed(content: string) {
const response = await client.embeddings.create({
model: EMBEDDING_MODEL,
input: content,
dimensions: EMBEDDING_DIMENSIONS,
});
return response.data[0].embedding;
}
データを入れる
今回下記のようなデータを使用しました。
import 'dotenv/config';
import { db } from './db';
import { postsTable } from "./db/schema";
import { embed } from './openai';
const TEST_DATA = [
'みかんを食べている男の人',
'レストランで食事する家族連れ',
'ギターを担いだ男二人がバーで飲んでいる',
'猫をなでる子供',
'散歩をするおじいさん',
]
async function insert() {
for (const testData of TEST_DATA) {
console.log(`Inserting ${testData}`);
const embedding = await embed(testData);
await db.insert(postsTable).values({
content: testData,
embedding,
});
}
}
insert();
検索する
下記のような検索コードを作りました。
import 'dotenv/config';
import { db } from './db';
import { postsTable } from "./db/schema";
import { embed } from './openai';
import { sql, cosineDistance } from 'drizzle-orm';
async function main() {
const query = process.argv[2];
if (!query) throw new Error('no query');
const embedding = await embed(query);
const result = await db
.select({
content: postsTable.content,
distance: cosineDistance(postsTable.embedding, embedding)
})
.from(postsTable)
.orderBy(cosineDistance(postsTable.embedding, embedding));
console.log(result);
}
main();
実際に検索してみます。
tsx src/main.ts 音楽
[
{ content: 'ギターを担いだ男二人がバーで飲んでいる', distance: 0.6905424367734843 },
{ content: 'みかんを食べている男の人', distance: 0.7711687249641749 },
{ content: 'レストランで食事する家族連れ', distance: 0.8035564848319032 },
{ content: '散歩をするおじいさん', distance: 0.8601853937186783 },
{ content: '猫をなでる子供', distance: 0.863712573629618 }
]
tsx src/main.ts お年寄り
[
{ content: '散歩をするおじいさん', distance: 0.6252080873274957 },
{ content: '猫をなでる子供', distance: 0.7478481741193026 },
{ content: 'レストランで食事する家族連れ', distance: 0.7724779558841293 },
{ content: 'みかんを食べている男の人', distance: 0.7884846037445112 },
{ content: 'ギターを担いだ男二人がバーで飲んでいる', distance: 0.8039060772629403 }
]
tsx src/main.ts cat
[
{ content: '猫をなでる子供', distance: 0.5634503482552038 },
{ content: 'みかんを食べている男の人', distance: 0.7431684418126188 },
{ content: '散歩をするおじいさん', distance: 0.7994899092432515 },
{ content: 'ギターを担いだ男二人がバーで飲んでいる', distance: 0.8381289147215432 },
{ content: 'レストランで食事する家族連れ', distance: 0.8834894558445278 }
]
うまい具合に、存在しない検索語句でもちゃんとdistanceが反映されています。言語が違っていても、大丈夫そうです。distanceの閾値は、0.7あたりが良さそうですね。
ただお年寄りを検索しているのに、子供が割と上位に来ています。猫は年寄りといったイメージがあるのでしょうか。あるいはモデルに、shortを使用しているからかもしれません。この辺りは、速度・費用・精度のトレードオフですね。