【AI】Postgres + Drizzle + Embeddingで意味検索する

こんにちは、フリーランスエンジニアの太田雅昭です。

構成

今回、下記を使用します。

  • Postgress
  • Drizzle
  • Open AI API

初めはPrismaで頑張っていたのですが、自由度が低く厳しそうでしたので、Drizzleに乗り換えた次第です。Drizzleならインデックス含めサクッとできました。感激。

Embedding

Embeddingは、最近流行りのAI技術です。文章をベクトル化し、意味合いのマッチ具合を判定できるようになります。

実装とテスト

準備

環境変数を設定します。

DATABASE_URL="postgresql://postgres:@localhost:5432/mydb"
OPENAI_API_KEY="sk-proj-xxx"

定数を定義します。

export const EMBEDDING_DIMENSIONS = 1536;
export const EMBEDDING_MODEL = "text-embedding-3-small";

スキーマを作ります。HNSWインデックス、cosineを使用しています。

import { index, integer, pgTable, varchar, vector } from "drizzle-orm/pg-core";
import { EMBEDDING_DIMENSIONS } from '../constants';

export const postsTable = pgTable(
  "posts",
  {
    id: integer().primaryKey().generatedAlwaysAsIdentity(),
    content: varchar({ length: 255 }).notNull(),
    embedding: vector({ dimensions: EMBEDDING_DIMENSIONS }).notNull(),
  },
  (table) => [
    index('embedding_hnsw_index').using('hnsw', table.embedding.op('vector_cosine_ops')),
  ]
);

dbをどこでも使用できるようにします。

import { drizzle } from 'drizzle-orm/node-postgres';

export const db = drizzle(process.env.DATABASE_URL!);

APIをラップします。

import OpenAI from "openai";
import { EMBEDDING_DIMENSIONS, EMBEDDING_MODEL } from "./constants";

const client = new OpenAI();

export async function embed(content: string) {
  const response = await client.embeddings.create({
    model: EMBEDDING_MODEL,
    input: content,
    dimensions: EMBEDDING_DIMENSIONS,
  });
  return response.data[0].embedding;
}

データを入れる

今回下記のようなデータを使用しました。

import 'dotenv/config';
import { db } from './db';
import { postsTable } from "./db/schema";
import { embed } from './openai';

const TEST_DATA = [
  'みかんを食べている男の人',
  'レストランで食事する家族連れ',
  'ギターを担いだ男二人がバーで飲んでいる',
  '猫をなでる子供',
  '散歩をするおじいさん',
]

async function insert() {
  for (const testData of TEST_DATA) {
    console.log(`Inserting ${testData}`);
    const embedding = await embed(testData);
    await db.insert(postsTable).values({
      content: testData,
      embedding,
    });
  }
}

insert();

検索する

下記のような検索コードを作りました。

import 'dotenv/config';
import { db } from './db';
import { postsTable } from "./db/schema";
import { embed } from './openai';
import { sql, cosineDistance } from 'drizzle-orm';


async function main() {
  const query = process.argv[2];
  if (!query) throw new Error('no query');
  const embedding = await embed(query);
  const result = await db
    .select({
      content: postsTable.content,
      distance: cosineDistance(postsTable.embedding, embedding)
    })
    .from(postsTable)
    .orderBy(cosineDistance(postsTable.embedding, embedding));
  console.log(result);
}

main();

実際に検索してみます。

tsx src/main.ts 音楽    

[
  { content: 'ギターを担いだ男二人がバーで飲んでいる', distance: 0.6905424367734843 },
  { content: 'みかんを食べている男の人', distance: 0.7711687249641749 },
  { content: 'レストランで食事する家族連れ', distance: 0.8035564848319032 },
  { content: '散歩をするおじいさん', distance: 0.8601853937186783 },
  { content: '猫をなでる子供', distance: 0.863712573629618 }
]
tsx src/main.ts お年寄り

[
  { content: '散歩をするおじいさん', distance: 0.6252080873274957 },
  { content: '猫をなでる子供', distance: 0.7478481741193026 },
  { content: 'レストランで食事する家族連れ', distance: 0.7724779558841293 },
  { content: 'みかんを食べている男の人', distance: 0.7884846037445112 },
  { content: 'ギターを担いだ男二人がバーで飲んでいる', distance: 0.8039060772629403 }
]
tsx src/main.ts cat 

[
  { content: '猫をなでる子供', distance: 0.5634503482552038 },
  { content: 'みかんを食べている男の人', distance: 0.7431684418126188 },
  { content: '散歩をするおじいさん', distance: 0.7994899092432515 },
  { content: 'ギターを担いだ男二人がバーで飲んでいる', distance: 0.8381289147215432 },
  { content: 'レストランで食事する家族連れ', distance: 0.8834894558445278 }
]

うまい具合に、存在しない検索語句でもちゃんとdistanceが反映されています。言語が違っていても、大丈夫そうです。distanceの閾値は、0.7あたりが良さそうですね。

ただお年寄りを検索しているのに、子供が割と上位に来ています。猫は年寄りといったイメージがあるのでしょうか。あるいはモデルに、shortを使用しているからかもしれません。この辺りは、速度・費用・精度のトレードオフですね。